From 7b49ed1bce227454a7f8b30a20507acb47253e0b Mon Sep 17 00:00:00 2001 From: Pxl Date: Fri, 26 Jul 2024 14:57:24 +0800 Subject: [PATCH] [Improvement](runtime-filter) use shared ptr to instead object pool to store runtime filters (#38085) ## Proposed changes use shared ptr to instead object pool to store runtime filters --- be/src/exprs/create_predicate_function.h | 4 +- be/src/exprs/minmax_predicate.h | 60 ++---- be/src/exprs/runtime_filter.cpp | 200 ++++++++---------- be/src/exprs/runtime_filter.h | 32 ++- be/src/exprs/runtime_filter_slots.h | 25 +-- be/src/exprs/runtime_filter_slots_cross.h | 15 +- .../common/runtime_filter_consumer.cpp | 10 +- .../pipeline/common/runtime_filter_consumer.h | 6 +- be/src/pipeline/exec/datagen_operator.cpp | 2 +- .../pipeline/exec/join_build_sink_operator.h | 6 +- be/src/runtime/fragment_mgr.cpp | 8 +- be/src/runtime/runtime_filter_mgr.cpp | 49 ++--- be/src/runtime/runtime_filter_mgr.h | 17 +- be/src/runtime/runtime_state.cpp | 13 +- be/src/runtime/runtime_state.h | 4 +- be/test/exprs/runtime_filter_test.cpp | 14 +- 16 files changed, 203 insertions(+), 262 deletions(-) diff --git a/be/src/exprs/create_predicate_function.h b/be/src/exprs/create_predicate_function.h index 11889ff2ec349b..4808caa00f37d0 100644 --- a/be/src/exprs/create_predicate_function.h +++ b/be/src/exprs/create_predicate_function.h @@ -34,7 +34,9 @@ class MinmaxFunctionTraits { using BasePtr = MinMaxFuncBase*; template static BasePtr get_function() { - return new MinMaxNumFunc::CppType>(); + using CppType = typename PrimitiveTypeTraits::CppType; + return new MinMaxNumFunc< + std::conditional_t, std::string, CppType>>(); } }; diff --git a/be/src/exprs/minmax_predicate.h b/be/src/exprs/minmax_predicate.h index b4291e2edb7e6b..377b33696c82b9 100644 --- a/be/src/exprs/minmax_predicate.h +++ b/be/src/exprs/minmax_predicate.h @@ -26,6 +26,7 @@ #include "vec/columns/column_nullable.h" #include "vec/columns/column_string.h" #include "vec/common/assert_cast.h" +#include "vec/common/string_ref.h" namespace doris { // only used in Runtime Filter @@ -75,19 +76,22 @@ class MinMaxNumFunc : public MinMaxFuncBase { for (size_t i = start; i < size; i++) { if (nullmap == nullptr || !nullmap[i]) { if constexpr (NeedMin) { - _min = std::min(_min, column_string.get_data_at(i)); + if (column_string.get_data_at(i) < StringRef(_min)) { + _min = column_string.get_data_at(i).to_string(); + } } if constexpr (NeedMax) { - _max = std::max(_max, column_string.get_data_at(i)); + if (column_string.get_data_at(i) > StringRef(_max)) { + _max = column_string.get_data_at(i).to_string(); + } } } } - store_string_ref(); } void update_batch(const vectorized::ColumnPtr& column, size_t start) { const auto size = column->size(); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { if (column->is_column_string64()) { _update_batch_string(assert_cast(*column), nullptr, start, size); @@ -111,7 +115,7 @@ class MinMaxNumFunc : public MinMaxFuncBase { void update_batch(const vectorized::ColumnPtr& column, const vectorized::NullMap& nullmap, size_t start) { const auto size = column->size(); - if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { if (column->is_column_string64()) { _update_batch_string(assert_cast(*column), nullmap.data(), start, size); @@ -135,26 +139,15 @@ class MinMaxNumFunc : public MinMaxFuncBase { } Status merge(MinMaxFuncBase* minmax_func) override { - if constexpr (std::is_same_v) { - auto* other_minmax = static_cast*>(minmax_func); - if constexpr (NeedMin) { - _min = std::min(_min, other_minmax->_min); - } - if constexpr (NeedMax) { - _max = std::max(_max, other_minmax->_max); - } - store_string_ref(); - } else { - auto* other_minmax = static_cast*>(minmax_func); - if constexpr (NeedMin) { - if (other_minmax->_min < _min) { - _min = other_minmax->_min; - } + auto* other_minmax = static_cast*>(minmax_func); + if constexpr (NeedMin) { + if (other_minmax->_min < _min) { + _min = other_minmax->_min; } - if constexpr (NeedMax) { - if (other_minmax->_max > _max) { - _max = other_minmax->_max; - } + } + if constexpr (NeedMax) { + if (other_minmax->_max > _max) { + _max = other_minmax->_max; } } @@ -172,28 +165,9 @@ class MinMaxNumFunc : public MinMaxFuncBase { return Status::OK(); } - void store_string_ref() { - if constexpr (std::is_same_v) { - if constexpr (NeedMin) { - if (_min.data != _stored_min.data()) { - _stored_min = _min.to_string(); - _min = StringRef(_stored_min); - } - } - if constexpr (NeedMax) { - if (_max.data != _stored_max.data()) { - _stored_max = _max.to_string(); - _max = StringRef(_stored_max); - } - } - } - } - protected: T _max = type_limit::min(); T _min = type_limit::max(); - std::string _stored_min; - std::string _stored_max; }; template diff --git a/be/src/exprs/runtime_filter.cpp b/be/src/exprs/runtime_filter.cpp index be993374cbfc5c..f61cebc8c054bc 100644 --- a/be/src/exprs/runtime_filter.cpp +++ b/be/src/exprs/runtime_filter.cpp @@ -35,7 +35,6 @@ #include "agent/be_exec_version_manager.h" #include "common/logging.h" -#include "common/object_pool.h" #include "common/status.h" #include "exprs/bitmapfilter_predicate.h" #include "exprs/bloom_filter_func.h" @@ -281,15 +280,13 @@ Status create_vbin_predicate(const TypeDescriptor& type, TExprOpcode::type opcod // This class is a wrapper of runtime predicate function class RuntimePredicateWrapper { public: - RuntimePredicateWrapper(ObjectPool* pool, const RuntimeFilterParams* params) - : RuntimePredicateWrapper(pool, params->column_return_type, params->filter_type, + RuntimePredicateWrapper(const RuntimeFilterParams* params) + : RuntimePredicateWrapper(params->column_return_type, params->filter_type, params->filter_id) {}; // for a 'tmp' runtime predicate wrapper // only could called assign method or as a param for merge - RuntimePredicateWrapper(ObjectPool* pool, PrimitiveType column_type, RuntimeFilterType type, - uint32_t filter_id) - : _pool(pool), - _column_return_type(column_type), + RuntimePredicateWrapper(PrimitiveType column_type, RuntimeFilterType type, uint32_t filter_id) + : _column_return_type(column_type), _filter_type(type), _context(new RuntimeFilterContext()), _filter_id(filter_id) {} @@ -566,51 +563,45 @@ class RuntimePredicateWrapper { switch (type) { case TYPE_BOOLEAN: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { bool bool_val = column.boolval(); set->insert(&bool_val); }); break; } case TYPE_TINYINT: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { - int8_t int_val = static_cast(column.intval()); + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { + auto int_val = static_cast(column.intval()); set->insert(&int_val); }); break; } case TYPE_SMALLINT: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { - int16_t int_val = static_cast(column.intval()); + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { + auto int_val = static_cast(column.intval()); set->insert(&int_val); }); break; } case TYPE_INT: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { int32_t int_val = column.intval(); set->insert(&int_val); }); break; } case TYPE_BIGINT: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { int64_t long_val = column.longval(); set->insert(&long_val); }); break; } case TYPE_LARGEINT: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { auto string_val = column.stringval(); StringParser::ParseResult result; - int128_t int128_val = StringParser::string_to_int( + auto int128_val = StringParser::string_to_int( string_val.c_str(), string_val.length(), &result); DCHECK(result == StringParser::PARSE_SUCCESS); set->insert(&int128_val); @@ -618,32 +609,28 @@ class RuntimePredicateWrapper { break; } case TYPE_FLOAT: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { - float float_val = static_cast(column.doubleval()); + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { + auto float_val = static_cast(column.doubleval()); set->insert(&float_val); }); break; } case TYPE_DOUBLE: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { double double_val = column.doubleval(); set->insert(&double_val); }); break; } case TYPE_DATEV2: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { auto date_v2_val = column.intval(); set->insert(&date_v2_val); }); break; } case TYPE_DATETIMEV2: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { auto date_v2_val = column.longval(); set->insert(&date_v2_val); }); @@ -651,9 +638,8 @@ class RuntimePredicateWrapper { } case TYPE_DATETIME: case TYPE_DATE: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { - auto& string_val_ref = column.stringval(); + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { + const auto& string_val_ref = column.stringval(); VecDateTimeValue datetime_val; datetime_val.from_date_str(string_val_ref.c_str(), string_val_ref.length()); set->insert(&datetime_val); @@ -661,36 +647,32 @@ class RuntimePredicateWrapper { break; } case TYPE_DECIMALV2: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { - auto& string_val_ref = column.stringval(); + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { + const auto& string_val_ref = column.stringval(); DecimalV2Value decimal_val(string_val_ref); set->insert(&decimal_val); }); break; } case TYPE_DECIMAL32: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { int32_t decimal_32_val = column.intval(); set->insert(&decimal_32_val); }); break; } case TYPE_DECIMAL64: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { int64_t decimal_64_val = column.longval(); set->insert(&decimal_64_val); }); break; } case TYPE_DECIMAL128I: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { auto string_val = column.stringval(); StringParser::ParseResult result; - int128_t int128_val = StringParser::string_to_int( + auto int128_val = StringParser::string_to_int( string_val.c_str(), string_val.length(), &result); DCHECK(result == StringParser::PARSE_SUCCESS); set->insert(&int128_val); @@ -698,8 +680,7 @@ class RuntimePredicateWrapper { break; } case TYPE_DECIMAL256: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { auto string_val = column.stringval(); StringParser::ParseResult result; auto int_val = StringParser::string_to_int( @@ -712,12 +693,9 @@ class RuntimePredicateWrapper { case TYPE_VARCHAR: case TYPE_CHAR: case TYPE_STRING: { - batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column, - ObjectPool* pool) { - auto& string_val_ref = column.stringval(); - auto val_ptr = pool->add(new std::string(string_val_ref)); - StringRef string_val(val_ptr->c_str(), val_ptr->length()); - set->insert(&string_val); + batch_assign(in_filter, [](std::shared_ptr& set, PColumnValue& column) { + const auto& string_val_ref = column.stringval(); + set->insert(&string_val_ref); }); break; } @@ -761,13 +739,13 @@ class RuntimePredicateWrapper { return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_TINYINT: { - int8_t min_val = static_cast(minmax_filter->min_val().intval()); - int8_t max_val = static_cast(minmax_filter->max_val().intval()); + auto min_val = static_cast(minmax_filter->min_val().intval()); + auto max_val = static_cast(minmax_filter->max_val().intval()); return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_SMALLINT: { - int16_t min_val = static_cast(minmax_filter->min_val().intval()); - int16_t max_val = static_cast(minmax_filter->max_val().intval()); + auto min_val = static_cast(minmax_filter->min_val().intval()); + auto max_val = static_cast(minmax_filter->max_val().intval()); return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_INT: { @@ -784,22 +762,22 @@ class RuntimePredicateWrapper { auto min_string_val = minmax_filter->min_val().stringval(); auto max_string_val = minmax_filter->max_val().stringval(); StringParser::ParseResult result; - int128_t min_val = StringParser::string_to_int( - min_string_val.c_str(), min_string_val.length(), &result); + auto min_val = StringParser::string_to_int(min_string_val.c_str(), + min_string_val.length(), &result); DCHECK(result == StringParser::PARSE_SUCCESS); - int128_t max_val = StringParser::string_to_int( - max_string_val.c_str(), max_string_val.length(), &result); + auto max_val = StringParser::string_to_int(max_string_val.c_str(), + max_string_val.length(), &result); DCHECK(result == StringParser::PARSE_SUCCESS); return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_FLOAT: { - float min_val = static_cast(minmax_filter->min_val().doubleval()); - float max_val = static_cast(minmax_filter->max_val().doubleval()); + auto min_val = static_cast(minmax_filter->min_val().doubleval()); + auto max_val = static_cast(minmax_filter->max_val().doubleval()); return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_DOUBLE: { - double min_val = static_cast(minmax_filter->min_val().doubleval()); - double max_val = static_cast(minmax_filter->max_val().doubleval()); + auto min_val = static_cast(minmax_filter->min_val().doubleval()); + auto max_val = static_cast(minmax_filter->max_val().doubleval()); return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_DATEV2: { @@ -814,8 +792,8 @@ class RuntimePredicateWrapper { } case TYPE_DATETIME: case TYPE_DATE: { - auto& min_val_ref = minmax_filter->min_val().stringval(); - auto& max_val_ref = minmax_filter->max_val().stringval(); + const auto& min_val_ref = minmax_filter->min_val().stringval(); + const auto& max_val_ref = minmax_filter->max_val().stringval(); VecDateTimeValue min_val; VecDateTimeValue max_val; min_val.from_date_str(min_val_ref.c_str(), min_val_ref.length()); @@ -823,8 +801,8 @@ class RuntimePredicateWrapper { return _context->minmax_func->assign(&min_val, &max_val); } case TYPE_DECIMALV2: { - auto& min_val_ref = minmax_filter->min_val().stringval(); - auto& max_val_ref = minmax_filter->max_val().stringval(); + const auto& min_val_ref = minmax_filter->min_val().stringval(); + const auto& max_val_ref = minmax_filter->max_val().stringval(); DecimalV2Value min_val(min_val_ref); DecimalV2Value max_val(max_val_ref); return _context->minmax_func->assign(&min_val, &max_val); @@ -843,11 +821,11 @@ class RuntimePredicateWrapper { auto min_string_val = minmax_filter->min_val().stringval(); auto max_string_val = minmax_filter->max_val().stringval(); StringParser::ParseResult result; - int128_t min_val = StringParser::string_to_int( - min_string_val.c_str(), min_string_val.length(), &result); + auto min_val = StringParser::string_to_int(min_string_val.c_str(), + min_string_val.length(), &result); DCHECK(result == StringParser::PARSE_SUCCESS); - int128_t max_val = StringParser::string_to_int( - max_string_val.c_str(), max_string_val.length(), &result); + auto max_val = StringParser::string_to_int(max_string_val.c_str(), + max_string_val.length(), &result); DCHECK(result == StringParser::PARSE_SUCCESS); return _context->minmax_func->assign(&min_val, &max_val); } @@ -866,13 +844,9 @@ class RuntimePredicateWrapper { case TYPE_VARCHAR: case TYPE_CHAR: case TYPE_STRING: { - auto& min_val_ref = minmax_filter->min_val().stringval(); - auto& max_val_ref = minmax_filter->max_val().stringval(); - auto min_val_ptr = _pool->add(new std::string(min_val_ref)); - auto max_val_ptr = _pool->add(new std::string(max_val_ref)); - StringRef min_val(min_val_ptr->c_str(), min_val_ptr->length()); - StringRef max_val(max_val_ptr->c_str(), max_val_ptr->length()); - return _context->minmax_func->assign(&min_val, &max_val); + auto min_val_ref = minmax_filter->min_val().stringval(); + auto max_val_ref = minmax_filter->max_val().stringval(); + return _context->minmax_func->assign(&min_val_ref, &max_val_ref); } default: break; @@ -915,10 +889,10 @@ class RuntimePredicateWrapper { void batch_assign(const PInFilter* filter, void (*assign_func)(std::shared_ptr& _hybrid_set, - PColumnValue&, ObjectPool*)) { + PColumnValue&)) { for (int i = 0; i < filter->values_size(); ++i) { PColumnValue column = filter->values(i); - assign_func(_context->hybrid_set, column, _pool); + assign_func(_context->hybrid_set, column); } } @@ -945,8 +919,6 @@ class RuntimePredicateWrapper { } private: - ObjectPool* _pool; - // When a runtime filter received from remote and it is a bloom filter, _column_return_type will be invalid. PrimitiveType _column_return_type; // column type RuntimeFilterType _filter_type; @@ -956,11 +928,11 @@ class RuntimePredicateWrapper { uint32_t _filter_id; }; -Status IRuntimeFilter::create(RuntimeFilterParamsContext* state, ObjectPool* pool, - const TRuntimeFilterDesc* desc, const TQueryOptions* query_options, - const RuntimeFilterRole role, int node_id, IRuntimeFilter** res, +Status IRuntimeFilter::create(RuntimeFilterParamsContext* state, const TRuntimeFilterDesc* desc, + const TQueryOptions* query_options, const RuntimeFilterRole role, + int node_id, std::shared_ptr* res, bool build_bf_exactly, bool need_local_merge) { - *res = pool->add(new IRuntimeFilter(state, pool, desc, need_local_merge)); + *res = std::make_shared(state, desc, need_local_merge); (*res)->set_role(role); return (*res)->init_with_desc(desc, query_options, node_id, build_bf_exactly); } @@ -983,12 +955,12 @@ Status IRuntimeFilter::publish(bool publish_local) { RETURN_IF_ERROR(_state->runtime_filter_mgr->get_merge_addr(&addr)); return filter->push_to_remote(&addr); }; - auto send_to_local = [&](RuntimePredicateWrapper* wrapper) { - std::vector filters; + auto send_to_local = [&](std::shared_ptr wrapper) { + std::vector> filters; RETURN_IF_ERROR(_state->runtime_filter_mgr->get_consume_filters(_filter_id, filters)); DCHECK(!filters.empty()); // push down - for (auto* filter : filters) { + for (auto filter : filters) { filter->_wrapper = wrapper; filter->update_runtime_filter_type_to_profile(); filter->signal(); @@ -1000,13 +972,13 @@ Status IRuntimeFilter::publish(bool publish_local) { RETURN_IF_ERROR(_state->runtime_filter_mgr->get_local_merge_producer_filters( _filter_id, &local_merge_filters)); std::lock_guard l(*local_merge_filters->lock); - RETURN_IF_ERROR(local_merge_filters->filters[0]->merge_from(_wrapper)); + RETURN_IF_ERROR(local_merge_filters->filters[0]->merge_from(_wrapper.get())); local_merge_filters->merge_time--; if (local_merge_filters->merge_time == 0) { if (_has_local_target) { RETURN_IF_ERROR(send_to_local(local_merge_filters->filters[0]->_wrapper)); } else { - RETURN_IF_ERROR(send_to_remote(local_merge_filters->filters[0])); + RETURN_IF_ERROR(send_to_remote(local_merge_filters->filters[0].get())); } } return Status::OK(); @@ -1081,7 +1053,7 @@ Status IRuntimeFilter::send_filter_size(RuntimeState* state, uint64_t local_filt return Status::OK(); } else { if (_has_local_target) { - for (auto* filter : local_merge_filters->filters) { + for (auto filter : local_merge_filters->filters) { filter->set_synced_size(local_merge_filters->local_merged_size); } return Status::OK(); @@ -1382,7 +1354,7 @@ Status IRuntimeFilter::init_with_desc(const TRuntimeFilterDesc* desc, const TQue _probe_expr = iter->second; } - _wrapper = _pool->add(new RuntimePredicateWrapper(_pool, ¶ms)); + _wrapper = std::make_shared(¶ms); return _wrapper->init(¶ms); } @@ -1398,22 +1370,22 @@ Status IRuntimeFilter::serialize(PPublishFilterRequestV2* request, void** data, return serialize_impl(request, data, len); } -Status IRuntimeFilter::create_wrapper(const MergeRuntimeFilterParams* param, ObjectPool* pool, +Status IRuntimeFilter::create_wrapper(const MergeRuntimeFilterParams* param, std::unique_ptr* wrapper) { - return _create_wrapper(param, pool, wrapper); + return _create_wrapper(param, wrapper); } -Status IRuntimeFilter::create_wrapper(const UpdateRuntimeFilterParams* param, ObjectPool* pool, +Status IRuntimeFilter::create_wrapper(const UpdateRuntimeFilterParams* param, std::unique_ptr* wrapper) { - return _create_wrapper(param, pool, wrapper); + return _create_wrapper(param, wrapper); } Status IRuntimeFilter::create_wrapper(const UpdateRuntimeFilterParamsV2* param, - RuntimePredicateWrapper** wrapper) { + std::shared_ptr* wrapper) { auto filter_type = param->request->filter_type(); PrimitiveType column_type = param->column_type; - *wrapper = param->pool->add(new RuntimePredicateWrapper( - param->pool, column_type, get_type(filter_type), param->request->filter_id())); + *wrapper = std::make_shared(column_type, get_type(filter_type), + param->request->filter_id()); if (param->request->has_ignored() && param->request->ignored()) { (*wrapper)->set_ignored(); @@ -1451,7 +1423,7 @@ Status IRuntimeFilter::init_bloom_filter(const size_t build_bf_cardinality) { } template -Status IRuntimeFilter::_create_wrapper(const T* param, ObjectPool* pool, +Status IRuntimeFilter::_create_wrapper(const T* param, std::unique_ptr* wrapper) { int filter_type = param->request->filter_type(); PrimitiveType column_type = PrimitiveType::INVALID_TYPE; @@ -1461,7 +1433,7 @@ Status IRuntimeFilter::_create_wrapper(const T* param, ObjectPool* pool, if (param->request->has_column_type()) { column_type = to_primitive_type(param->request->column_type()); } - *wrapper = std::make_unique(pool, column_type, get_type(filter_type), + *wrapper = std::make_unique(column_type, get_type(filter_type), param->request->filter_id()); if (param->request->has_ignored() && param->request->ignored()) { @@ -1562,7 +1534,7 @@ void IRuntimeFilter::to_protobuf(PInFilter* filter) { auto column_type = _wrapper->column_type(); filter->set_column_type(to_proto(column_type)); - auto it = _wrapper->get_in_filter_iterator(); + auto* it = _wrapper->get_in_filter_iterator(); DCHECK(it != nullptr); switch (column_type) { @@ -1673,8 +1645,8 @@ void IRuntimeFilter::to_protobuf(PInFilter* filter) { case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_STRING: { - batch_copy(filter, it, [](PColumnValue* column, const StringRef* value) { - column->set_stringval(std::string(value->data, value->size)); + batch_copy(filter, it, [](PColumnValue* column, const std::string* value) { + column->set_stringval(*value); }); return; } @@ -1788,12 +1760,10 @@ void IRuntimeFilter::to_protobuf(PMinMaxFilter* filter) { case TYPE_CHAR: case TYPE_VARCHAR: case TYPE_STRING: { - const StringRef* min_string_value = reinterpret_cast(min_data); - filter->mutable_min_val()->set_stringval( - std::string(min_string_value->data, min_string_value->size)); - const StringRef* max_string_value = reinterpret_cast(max_data); - filter->mutable_max_val()->set_stringval( - std::string(max_string_value->data, max_string_value->size)); + const auto* min_string_value = reinterpret_cast(min_data); + filter->mutable_min_val()->set_stringval(*min_string_value); + const auto* max_string_value = reinterpret_cast(max_data); + filter->mutable_max_val()->set_stringval(*max_string_value); break; } default: { @@ -1820,7 +1790,7 @@ Status IRuntimeFilter::update_filter(const UpdateRuntimeFilterParams* param) { set_ignored(); } else { std::unique_ptr wrapper; - RETURN_IF_ERROR(IRuntimeFilter::create_wrapper(param, _pool, &wrapper)); + RETURN_IF_ERROR(IRuntimeFilter::create_wrapper(param, &wrapper)); RETURN_IF_ERROR(_wrapper->merge(wrapper.get())); update_runtime_filter_type_to_profile(); } @@ -1829,8 +1799,8 @@ Status IRuntimeFilter::update_filter(const UpdateRuntimeFilterParams* param) { return Status::OK(); } -void IRuntimeFilter::update_filter(RuntimePredicateWrapper* wrapper, int64_t merge_time, - int64_t start_apply) { +void IRuntimeFilter::update_filter(std::shared_ptr wrapper, + int64_t merge_time, int64_t start_apply) { _profile->add_info_string("UpdateTime", std::to_string(MonotonicMillis() - start_apply) + " ms"); _profile->add_info_string("MergeTime", std::to_string(merge_time) + " ms"); diff --git a/be/src/exprs/runtime_filter.h b/be/src/exprs/runtime_filter.h index 9f0ba786238135..9bf27025876f15 100644 --- a/be/src/exprs/runtime_filter.h +++ b/be/src/exprs/runtime_filter.h @@ -48,7 +48,6 @@ class IOBufAsZeroCopyInputStream; } namespace doris { -class ObjectPool; class RuntimePredicateWrapper; class PPublishFilterRequest; class PPublishFilterRequestV2; @@ -158,17 +157,15 @@ struct RuntimeFilterFuncBase { struct UpdateRuntimeFilterParams { UpdateRuntimeFilterParams(const PPublishFilterRequest* req, - butil::IOBufAsZeroCopyInputStream* data_stream, ObjectPool* obj_pool) - : request(req), data(data_stream), pool(obj_pool) {} + butil::IOBufAsZeroCopyInputStream* data_stream) + : request(req), data(data_stream) {} const PPublishFilterRequest* request = nullptr; butil::IOBufAsZeroCopyInputStream* data = nullptr; - ObjectPool* pool = nullptr; }; struct UpdateRuntimeFilterParamsV2 { const PPublishFilterRequestV2* request; butil::IOBufAsZeroCopyInputStream* data; - ObjectPool* pool = nullptr; PrimitiveType column_type = INVALID_TYPE; }; @@ -193,10 +190,9 @@ enum RuntimeFilterState { /// that can be pushed down to node based on the results of the right table. class IRuntimeFilter { public: - IRuntimeFilter(RuntimeFilterParamsContext* state, ObjectPool* pool, - const TRuntimeFilterDesc* desc, bool need_local_merge = false) + IRuntimeFilter(RuntimeFilterParamsContext* state, const TRuntimeFilterDesc* desc, + bool need_local_merge = false) : _state(state), - _pool(pool), _filter_id(desc->filter_id), _is_broadcast_join(true), _has_remote_target(false), @@ -216,9 +212,9 @@ class IRuntimeFilter { ~IRuntimeFilter() = default; - static Status create(RuntimeFilterParamsContext* state, ObjectPool* pool, - const TRuntimeFilterDesc* desc, const TQueryOptions* query_options, - const RuntimeFilterRole role, int node_id, IRuntimeFilter** res, + static Status create(RuntimeFilterParamsContext* state, const TRuntimeFilterDesc* desc, + const TQueryOptions* query_options, const RuntimeFilterRole role, + int node_id, std::shared_ptr* res, bool build_bf_exactly = false, bool need_local_merge = false); RuntimeFilterContextSPtr& get_shared_context_ref(); @@ -282,17 +278,17 @@ class IRuntimeFilter { Status merge_from(const RuntimePredicateWrapper* wrapper); - static Status create_wrapper(const MergeRuntimeFilterParams* param, ObjectPool* pool, + static Status create_wrapper(const MergeRuntimeFilterParams* param, std::unique_ptr* wrapper); - static Status create_wrapper(const UpdateRuntimeFilterParams* param, ObjectPool* pool, + static Status create_wrapper(const UpdateRuntimeFilterParams* param, std::unique_ptr* wrapper); static Status create_wrapper(const UpdateRuntimeFilterParamsV2* param, - RuntimePredicateWrapper** wrapper); + std::shared_ptr* wrapper); Status change_to_bloom_filter(); Status init_bloom_filter(const size_t build_bf_cardinality); Status update_filter(const UpdateRuntimeFilterParams* param); - void update_filter(RuntimePredicateWrapper* filter_wrapper, int64_t merge_time, + void update_filter(std::shared_ptr filter_wrapper, int64_t merge_time, int64_t start_apply); void set_ignored(); @@ -382,7 +378,7 @@ class IRuntimeFilter { Status serialize_impl(T* request, void** data, int* len); template - static Status _create_wrapper(const T* param, ObjectPool* pool, + static Status _create_wrapper(const T* param, std::unique_ptr* wrapper); void _set_push_down(bool push_down) { _is_push_down = push_down; } @@ -396,10 +392,8 @@ class IRuntimeFilter { } RuntimeFilterParamsContext* _state = nullptr; - ObjectPool* _pool = nullptr; // _wrapper is a runtime filter function wrapper - // _wrapper should alloc from _pool - RuntimePredicateWrapper* _wrapper = nullptr; + std::shared_ptr _wrapper; // runtime filter id int _filter_id; // Specific types BoardCast or Shuffle diff --git a/be/src/exprs/runtime_filter_slots.h b/be/src/exprs/runtime_filter_slots.h index ebda4b56fcc30e..c0a249cd6b063d 100644 --- a/be/src/exprs/runtime_filter_slots.h +++ b/be/src/exprs/runtime_filter_slots.h @@ -34,10 +34,10 @@ class VRuntimeFilterSlots { public: VRuntimeFilterSlots( const std::vector>& build_expr_ctxs, - const std::vector& runtime_filters) + const std::vector>& runtime_filters) : _build_expr_context(build_expr_ctxs), _runtime_filters(runtime_filters) { - for (auto* runtime_filter : _runtime_filters) { - _runtime_filters_map[runtime_filter->expr_order()].push_back(runtime_filter); + for (auto runtime_filter : _runtime_filters) { + _runtime_filters_map[runtime_filter->expr_order()].push_back(runtime_filter.get()); } } @@ -46,14 +46,14 @@ class VRuntimeFilterSlots { if (_runtime_filters.empty()) { return Status::OK(); } - for (auto* runtime_filter : _runtime_filters) { + for (auto runtime_filter : _runtime_filters) { if (runtime_filter->need_sync_filter_size()) { runtime_filter->set_dependency(dependency); } } // send_filter_size may call dependency->sub(), so we call set_dependency firstly for all rf to avoid dependency set_ready repeatedly - for (auto* runtime_filter : _runtime_filters) { + for (auto runtime_filter : _runtime_filters) { if (runtime_filter->need_sync_filter_size()) { RETURN_IF_ERROR(runtime_filter->send_filter_size(state, hash_table_size)); } @@ -70,7 +70,7 @@ class VRuntimeFilterSlots { Status ignore_filters(RuntimeState* state) { // process ignore duplicate IN_FILTER std::unordered_set has_in_filter; - for (auto* filter : _runtime_filters) { + for (auto filter : _runtime_filters) { if (filter->get_ignored()) { continue; } @@ -85,7 +85,7 @@ class VRuntimeFilterSlots { } // process ignore filter when it has IN_FILTER on same expr, and init bloom filter size - for (auto* filter : _runtime_filters) { + for (auto filter : _runtime_filters) { if (filter->get_ignored()) { continue; } @@ -100,12 +100,13 @@ class VRuntimeFilterSlots { Status init_filters(RuntimeState* state, uint64_t local_hash_table_size) { // process IN_OR_BLOOM_FILTER's real type - for (auto* filter : _runtime_filters) { + for (auto filter : _runtime_filters) { if (filter->get_ignored()) { continue; } if (filter->type() == RuntimeFilterType::IN_OR_BLOOM_FILTER && - get_real_size(filter, local_hash_table_size) > state->runtime_filter_max_in_num()) { + get_real_size(filter.get(), local_hash_table_size) > + state->runtime_filter_max_in_num()) { RETURN_IF_ERROR(filter->change_to_bloom_filter()); } @@ -114,8 +115,8 @@ class VRuntimeFilterSlots { return Status::InternalError("sync filter size meet error, filter: {}", filter->debug_string()); } - RETURN_IF_ERROR( - filter->init_bloom_filter(get_real_size(filter, local_hash_table_size))); + RETURN_IF_ERROR(filter->init_bloom_filter( + get_real_size(filter.get(), local_hash_table_size))); } } return Status::OK(); @@ -175,7 +176,7 @@ class VRuntimeFilterSlots { private: const std::vector>& _build_expr_context; - std::vector _runtime_filters; + std::vector> _runtime_filters; // prob_contition index -> [IRuntimeFilter] std::map> _runtime_filters_map; }; diff --git a/be/src/exprs/runtime_filter_slots_cross.h b/be/src/exprs/runtime_filter_slots_cross.h index 1d496ddf5571e6..01ae21a75992de 100644 --- a/be/src/exprs/runtime_filter_slots_cross.h +++ b/be/src/exprs/runtime_filter_slots_cross.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include "common/status.h" @@ -34,14 +35,14 @@ namespace doris { // this class used in cross join node class VRuntimeFilterSlotsCross { public: - VRuntimeFilterSlotsCross(const std::vector& runtime_filters, - const vectorized::VExprContextSPtrs& src_expr_ctxs) - : _runtime_filters(runtime_filters), filter_src_expr_ctxs(src_expr_ctxs) {} + VRuntimeFilterSlotsCross(const std::vector>& runtime_filters, + vectorized::VExprContextSPtrs src_expr_ctxs) + : _runtime_filters(runtime_filters), filter_src_expr_ctxs(std::move(src_expr_ctxs)) {} ~VRuntimeFilterSlotsCross() = default; Status init(RuntimeState* state) { - for (auto* runtime_filter : _runtime_filters) { + for (auto runtime_filter : _runtime_filters) { if (runtime_filter == nullptr) { return Status::InternalError("runtime filter is nullptr"); } @@ -56,7 +57,7 @@ class VRuntimeFilterSlotsCross { Status insert(vectorized::Block* block) { for (int i = 0; i < _runtime_filters.size(); ++i) { - auto* filter = _runtime_filters[i]; + auto filter = _runtime_filters[i]; const auto& vexpr_ctx = filter_src_expr_ctxs[i]; int result_column_id = -1; @@ -72,7 +73,7 @@ class VRuntimeFilterSlotsCross { } Status publish() { - for (auto& filter : _runtime_filters) { + for (auto filter : _runtime_filters) { RETURN_IF_ERROR(filter->publish()); } return Status::OK(); @@ -81,7 +82,7 @@ class VRuntimeFilterSlotsCross { bool empty() const { return _runtime_filters.empty(); } private: - const std::vector& _runtime_filters; + const std::vector>& _runtime_filters; const vectorized::VExprContextSPtrs filter_src_expr_ctxs; }; diff --git a/be/src/pipeline/common/runtime_filter_consumer.cpp b/be/src/pipeline/common/runtime_filter_consumer.cpp index 0e9c2d0f304c79..57397efd21185f 100644 --- a/be/src/pipeline/common/runtime_filter_consumer.cpp +++ b/be/src/pipeline/common/runtime_filter_consumer.cpp @@ -52,7 +52,7 @@ Status RuntimeFilterConsumer::_register_runtime_filter(bool need_local_merge) { _runtime_filter_ctxs.reserve(filter_size); _runtime_filter_ready_flag.reserve(filter_size); for (int i = 0; i < filter_size; ++i) { - IRuntimeFilter* runtime_filter = nullptr; + std::shared_ptr runtime_filter; const auto& filter_desc = _runtime_filter_descs[i]; RETURN_IF_ERROR(_state->register_consumer_runtime_filter(filter_desc, need_local_merge, _filter_id, &runtime_filter)); @@ -73,9 +73,9 @@ void RuntimeFilterConsumer::init_runtime_filter_dependency( local_runtime_filter_dependencies; for (size_t i = 0; i < _runtime_filter_descs.size(); ++i) { - IRuntimeFilter* runtime_filter = _runtime_filter_ctxs[i].runtime_filter; + auto runtime_filter = _runtime_filter_ctxs[i].runtime_filter; runtime_filter_dependencies[i] = std::make_shared( - id, node_id, name, runtime_filter); + id, node_id, name, runtime_filter.get()); _runtime_filter_ctxs[i].runtime_filter_dependency = runtime_filter_dependencies[i].get(); runtime_filter_timers[i] = std::make_shared( runtime_filter->registration_time(), runtime_filter->wait_time_ms(), @@ -89,7 +89,7 @@ void RuntimeFilterConsumer::init_runtime_filter_dependency( // The gloabl runtime filter timer need set local runtime filter dependencies. // start to wait before the local runtime filter ready for (size_t i = 0; i < _runtime_filter_descs.size(); ++i) { - IRuntimeFilter* runtime_filter = _runtime_filter_ctxs[i].runtime_filter; + auto runtime_filter = _runtime_filter_ctxs[i].runtime_filter; if (!runtime_filter->has_local_target()) { runtime_filter_timers[i]->set_local_runtime_filter_dependencies( local_runtime_filter_dependencies); @@ -105,7 +105,7 @@ Status RuntimeFilterConsumer::_acquire_runtime_filter(bool pipeline_x) { SCOPED_TIMER(_acquire_runtime_filter_timer); std::vector vexprs; for (size_t i = 0; i < _runtime_filter_descs.size(); ++i) { - IRuntimeFilter* runtime_filter = _runtime_filter_ctxs[i].runtime_filter; + auto runtime_filter = _runtime_filter_ctxs[i].runtime_filter; if (pipeline_x) { runtime_filter->update_state(); if (runtime_filter->is_ready() && !_runtime_filter_ctxs[i].apply_mark) { diff --git a/be/src/pipeline/common/runtime_filter_consumer.h b/be/src/pipeline/common/runtime_filter_consumer.h index 9bee6053f6f7d5..4b500a916f0e47 100644 --- a/be/src/pipeline/common/runtime_filter_consumer.h +++ b/be/src/pipeline/common/runtime_filter_consumer.h @@ -17,6 +17,8 @@ #pragma once +#include + #include "exprs/runtime_filter.h" #include "pipeline/dependency.h" @@ -55,10 +57,10 @@ class RuntimeFilterConsumer { // For runtime filters struct RuntimeFilterContext { - RuntimeFilterContext(IRuntimeFilter* rf) : runtime_filter(rf) {} + RuntimeFilterContext(std::shared_ptr rf) : runtime_filter(std::move(rf)) {} // set to true if this runtime filter is already applied to vconjunct_ctx_ptr bool apply_mark = false; - IRuntimeFilter* runtime_filter = nullptr; + std::shared_ptr runtime_filter; pipeline::RuntimeFilterDependency* runtime_filter_dependency = nullptr; }; diff --git a/be/src/pipeline/exec/datagen_operator.cpp b/be/src/pipeline/exec/datagen_operator.cpp index 48e428ceef42cf..dae39f179a68f2 100644 --- a/be/src/pipeline/exec/datagen_operator.cpp +++ b/be/src/pipeline/exec/datagen_operator.cpp @@ -86,7 +86,7 @@ Status DataGenLocalState::init(RuntimeState* state, LocalStateInfo& info) { // TODO: use runtime filter to filte result block, maybe this node need derive from vscan_node. for (const auto& filter_desc : p._runtime_filter_descs) { - IRuntimeFilter* runtime_filter = nullptr; + std::shared_ptr runtime_filter; RETURN_IF_ERROR(state->register_consumer_runtime_filter( filter_desc, p.ignore_data_distribution(), p.node_id(), &runtime_filter)); runtime_filter->init_profile(_runtime_profile.get()); diff --git a/be/src/pipeline/exec/join_build_sink_operator.h b/be/src/pipeline/exec/join_build_sink_operator.h index d43a6d1bf9d6ef..714e0c34190678 100644 --- a/be/src/pipeline/exec/join_build_sink_operator.h +++ b/be/src/pipeline/exec/join_build_sink_operator.h @@ -28,7 +28,9 @@ class JoinBuildSinkLocalState : public PipelineXSinkLocalState public: Status init(RuntimeState* state, LocalSinkStateInfo& info) override; - const std::vector& runtime_filters() const { return _runtime_filters; } + const std::vector>& runtime_filters() const { + return _runtime_filters; + } protected: JoinBuildSinkLocalState(DataSinkOperatorXBase* parent, RuntimeState* state) @@ -41,7 +43,7 @@ class JoinBuildSinkLocalState : public PipelineXSinkLocalState RuntimeProfile::Counter* _publish_runtime_filter_timer = nullptr; RuntimeProfile::Counter* _runtime_filter_compute_timer = nullptr; RuntimeProfile::Counter* _runtime_filter_init_timer = nullptr; - std::vector _runtime_filters; + std::vector> _runtime_filters; }; template diff --git a/be/src/runtime/fragment_mgr.cpp b/be/src/runtime/fragment_mgr.cpp index d6bbba016be1df..057dca4a2ee18b 100644 --- a/be/src/runtime/fragment_mgr.cpp +++ b/be/src/runtime/fragment_mgr.cpp @@ -1054,7 +1054,6 @@ Status FragmentMgr::apply_filterv2(const PPublishFilterRequestV2* request, QueryThreadContext query_thread_context; RuntimeFilterMgr* runtime_filter_mgr = nullptr; - ObjectPool* pool = nullptr; const auto& fragment_instance_ids = request->fragment_instance_ids(); { @@ -1071,7 +1070,6 @@ Status FragmentMgr::apply_filterv2(const PPublishFilterRequestV2* request, DCHECK(pip_context != nullptr); runtime_filter_mgr = pip_context->get_query_ctx()->runtime_filter_mgr(); - pool = &pip_context->get_query_ctx()->obj_pool; query_thread_context = {pip_context->get_query_ctx()->query_id(), pip_context->get_query_ctx()->query_mem_tracker, pip_context->get_query_ctx()->workload_group()}; @@ -1089,13 +1087,13 @@ Status FragmentMgr::apply_filterv2(const PPublishFilterRequestV2* request, SCOPED_ATTACH_TASK(query_thread_context); // 1. get the target filters - std::vector filters; + std::vector> filters; RETURN_IF_ERROR(runtime_filter_mgr->get_consume_filters(request->filter_id(), filters)); // 2. create the filter wrapper to replace or ignore the target filters if (!filters.empty()) { - UpdateRuntimeFilterParamsV2 params {request, attach_data, pool, filters[0]->column_type()}; - RuntimePredicateWrapper* filter_wrapper = nullptr; + UpdateRuntimeFilterParamsV2 params {request, attach_data, filters[0]->column_type()}; + std::shared_ptr filter_wrapper; RETURN_IF_ERROR(IRuntimeFilter::create_wrapper(¶ms, &filter_wrapper)); std::ranges::for_each(filters, [&](auto& filter) { diff --git a/be/src/runtime/runtime_filter_mgr.cpp b/be/src/runtime/runtime_filter_mgr.cpp index 0e5b37c8ffa220..625b487d0ee1f3 100644 --- a/be/src/runtime/runtime_filter_mgr.cpp +++ b/be/src/runtime/runtime_filter_mgr.cpp @@ -58,8 +58,8 @@ RuntimeFilterMgr::~RuntimeFilterMgr() { _pool.clear(); } -Status RuntimeFilterMgr::get_consume_filters(const int filter_id, - std::vector& consumer_filters) { +Status RuntimeFilterMgr::get_consume_filters( + const int filter_id, std::vector>& consumer_filters) { std::lock_guard l(_lock); auto iter = _consumer_map.find(filter_id); if (iter == _consumer_map.end()) { @@ -74,7 +74,7 @@ Status RuntimeFilterMgr::get_consume_filters(const int filter_id, Status RuntimeFilterMgr::register_consumer_filter(const TRuntimeFilterDesc& desc, const TQueryOptions& options, int node_id, - IRuntimeFilter** consumer_filter, + std::shared_ptr* consumer_filter, bool build_bf_exactly, bool need_local_merge) { SCOPED_CONSUME_MEM_TRACKER(_tracker.get()); int32_t key = desc.filter_id; @@ -91,10 +91,10 @@ Status RuntimeFilterMgr::register_consumer_filter(const TRuntimeFilterDesc& desc } if (!has_exist) { - IRuntimeFilter* filter; - RETURN_IF_ERROR(IRuntimeFilter::create(_state, &_pool, &desc, &options, - RuntimeFilterRole::CONSUMER, node_id, &filter, - build_bf_exactly, need_local_merge)); + std::shared_ptr filter; + RETURN_IF_ERROR(IRuntimeFilter::create(_state, &desc, &options, RuntimeFilterRole::CONSUMER, + node_id, &filter, build_bf_exactly, + need_local_merge)); _consumer_map[key].emplace_back(node_id, filter); *consumer_filter = filter; } else if (!need_local_merge) { @@ -106,7 +106,7 @@ Status RuntimeFilterMgr::register_consumer_filter(const TRuntimeFilterDesc& desc Status RuntimeFilterMgr::register_local_merge_producer_filter( const doris::TRuntimeFilterDesc& desc, const doris::TQueryOptions& options, - doris::IRuntimeFilter** producer_filter, bool build_bf_exactly) { + std::shared_ptr* producer_filter, bool build_bf_exactly) { SCOPED_CONSUME_MEM_TRACKER(_tracker.get()); int32_t key = desc.filter_id; @@ -121,14 +121,13 @@ Status RuntimeFilterMgr::register_local_merge_producer_filter( } DCHECK(_state != nullptr); - RETURN_IF_ERROR(IRuntimeFilter::create(_state, &_pool, &desc, &options, - RuntimeFilterRole::PRODUCER, -1, producer_filter, - build_bf_exactly, true)); + RETURN_IF_ERROR(IRuntimeFilter::create(_state, &desc, &options, RuntimeFilterRole::PRODUCER, -1, + producer_filter, build_bf_exactly, true)); { std::lock_guard l(*iter->second.lock); if (iter->second.filters.empty()) { - IRuntimeFilter* merge_filter = nullptr; - RETURN_IF_ERROR(IRuntimeFilter::create(_state, &_pool, &desc, &options, + std::shared_ptr merge_filter; + RETURN_IF_ERROR(IRuntimeFilter::create(_state, &desc, &options, RuntimeFilterRole::PRODUCER, -1, &merge_filter, build_bf_exactly, true)); iter->second.filters.emplace_back(merge_filter); @@ -158,7 +157,7 @@ Status RuntimeFilterMgr::get_local_merge_producer_filters( Status RuntimeFilterMgr::register_producer_filter(const TRuntimeFilterDesc& desc, const TQueryOptions& options, - IRuntimeFilter** producer_filter, + std::shared_ptr* producer_filter, bool build_bf_exactly) { SCOPED_CONSUME_MEM_TRACKER(_tracker.get()); int32_t key = desc.filter_id; @@ -169,9 +168,8 @@ Status RuntimeFilterMgr::register_producer_filter(const TRuntimeFilterDesc& desc if (iter != _producer_map.end()) { return Status::InvalidArgument("filter has registed"); } - RETURN_IF_ERROR(IRuntimeFilter::create(_state, &_pool, &desc, &options, - RuntimeFilterRole::PRODUCER, -1, producer_filter, - build_bf_exactly)); + RETURN_IF_ERROR(IRuntimeFilter::create(_state, &desc, &options, RuntimeFilterRole::PRODUCER, -1, + producer_filter, build_bf_exactly)); _producer_map.emplace(key, *producer_filter); return Status::OK(); } @@ -179,9 +177,9 @@ Status RuntimeFilterMgr::register_producer_filter(const TRuntimeFilterDesc& desc Status RuntimeFilterMgr::update_filter(const PPublishFilterRequest* request, butil::IOBufAsZeroCopyInputStream* data) { SCOPED_CONSUME_MEM_TRACKER(_tracker.get()); - UpdateRuntimeFilterParams params(request, data, &_pool); + UpdateRuntimeFilterParams params(request, data); int filter_id = request->filter_id(); - std::vector filters; + std::vector> filters; // The code is organized for upgrade compatibility to prevent infinite waiting // old way update filter the code should be deleted after the upgrade is complete. { @@ -196,7 +194,7 @@ Status RuntimeFilterMgr::update_filter(const PPublishFilterRequest* request, } iter->second.clear(); } - for (auto* filter : filters) { + for (auto filter : filters) { RETURN_IF_ERROR(filter->update_filter(¶ms)); } @@ -233,8 +231,7 @@ Status RuntimeFilterMergeControllerEntity::_init_with_desc( cnt_val->runtime_filter_desc = *runtime_filter_desc; cnt_val->target_info = *target_info; cnt_val->pool.reset(new ObjectPool()); - cnt_val->filter = cnt_val->pool->add( - new IRuntimeFilter(_state, &_state->get_query_ctx()->obj_pool, runtime_filter_desc)); + cnt_val->filter = cnt_val->pool->add(new IRuntimeFilter(_state, runtime_filter_desc)); auto filter_id = runtime_filter_desc->filter_id; RETURN_IF_ERROR(cnt_val->filter->init_with_desc(&cnt_val->runtime_filter_desc, query_options, @@ -254,8 +251,7 @@ Status RuntimeFilterMergeControllerEntity::_init_with_desc( cnt_val->runtime_filter_desc = *runtime_filter_desc; cnt_val->targetv2_info = *targetv2_info; cnt_val->pool.reset(new ObjectPool()); - cnt_val->filter = cnt_val->pool->add( - new IRuntimeFilter(_state, &_state->get_query_ctx()->obj_pool, runtime_filter_desc)); + cnt_val->filter = cnt_val->pool->add(new IRuntimeFilter(_state, runtime_filter_desc)); auto filter_id = runtime_filter_desc->filter_id; RETURN_IF_ERROR(cnt_val->filter->init_with_desc(&cnt_val->runtime_filter_desc, query_options)); @@ -355,7 +351,7 @@ Status RuntimeFilterMergeControllerEntity::send_filter_size(const PSendFilterSiz } Status RuntimeFilterMgr::sync_filter_size(const PSyncFilterSizeRequest* request) { - auto* filter = try_get_product_filter(request->filter_id()); + auto filter = try_get_product_filter(request->filter_id()); if (filter) { filter->set_synced_size(request->filter_size()); return Status::OK(); @@ -397,9 +393,8 @@ Status RuntimeFilterMergeControllerEntity::merge(const PMergeFilterRequest* requ return Status::OK(); } MergeRuntimeFilterParams params(request, attach_data); - ObjectPool* pool = cnt_val->pool.get(); RuntimeFilterWrapperHolder holder; - RETURN_IF_ERROR(IRuntimeFilter::create_wrapper(¶ms, pool, holder.getHandle())); + RETURN_IF_ERROR(IRuntimeFilter::create_wrapper(¶ms, holder.getHandle())); auto st = cnt_val->filter->merge_from(holder.getHandle()->get()); if (!st) { diff --git a/be/src/runtime/runtime_filter_mgr.h b/be/src/runtime/runtime_filter_mgr.h index 9b0216e07786d6..d89a3b9f1b1768 100644 --- a/be/src/runtime/runtime_filter_mgr.h +++ b/be/src/runtime/runtime_filter_mgr.h @@ -59,7 +59,7 @@ struct LocalMergeFilters { int merge_time = 0; int merge_size_times = 0; uint64_t local_merged_size = 0; - std::vector filters; + std::vector> filters; }; /// producer: @@ -81,9 +81,10 @@ class RuntimeFilterMgr { ~RuntimeFilterMgr(); - Status get_consume_filters(const int filter_id, std::vector& consumer_filters); + Status get_consume_filters(const int filter_id, + std::vector>& consumer_filters); - IRuntimeFilter* try_get_product_filter(const int filter_id) { + std::shared_ptr try_get_product_filter(const int filter_id) { std::lock_guard l(_lock); auto iter = _producer_map.find(filter_id); if (iter == _producer_map.end()) { @@ -94,18 +95,18 @@ class RuntimeFilterMgr { // register filter Status register_consumer_filter(const TRuntimeFilterDesc& desc, const TQueryOptions& options, - int node_id, IRuntimeFilter** consumer_filter, + int node_id, std::shared_ptr* consumer_filter, bool build_bf_exactly = false, bool need_local_merge = false); Status register_local_merge_producer_filter(const TRuntimeFilterDesc& desc, const TQueryOptions& options, - IRuntimeFilter** producer_filter, + std::shared_ptr* producer_filter, bool build_bf_exactly = false); Status get_local_merge_producer_filters(int filter_id, LocalMergeFilters** local_merge_filters); Status register_producer_filter(const TRuntimeFilterDesc& desc, const TQueryOptions& options, - IRuntimeFilter** producer_filter, + std::shared_ptr* producer_filter, bool build_bf_exactly = false); // update filter by remote @@ -121,13 +122,13 @@ class RuntimeFilterMgr { private: struct ConsumerFilterHolder { int node_id; - IRuntimeFilter* filter = nullptr; + std::shared_ptr filter; }; // RuntimeFilterMgr is owned by RuntimeState, so we only // use filter_id as key // key: "filter-id" std::map> _consumer_map; - std::map _producer_map; + std::map> _producer_map; std::map _local_merge_producer_map; RuntimeFilterParamsContext* _state = nullptr; diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp index 5471a01c246346..34aa457d5a6afb 100644 --- a/be/src/runtime/runtime_state.cpp +++ b/be/src/runtime/runtime_state.cpp @@ -533,10 +533,9 @@ RuntimeFilterMgr* RuntimeState::global_runtime_filter_mgr() { return _query_ctx->runtime_filter_mgr(); } -Status RuntimeState::register_producer_runtime_filter(const doris::TRuntimeFilterDesc& desc, - bool need_local_merge, - doris::IRuntimeFilter** producer_filter, - bool build_bf_exactly) { +Status RuntimeState::register_producer_runtime_filter( + const TRuntimeFilterDesc& desc, bool need_local_merge, + std::shared_ptr* producer_filter, bool build_bf_exactly) { if (desc.has_remote_targets || need_local_merge) { return global_runtime_filter_mgr()->register_local_merge_producer_filter( desc, query_options(), producer_filter, build_bf_exactly); @@ -546,9 +545,9 @@ Status RuntimeState::register_producer_runtime_filter(const doris::TRuntimeFilte } } -Status RuntimeState::register_consumer_runtime_filter(const doris::TRuntimeFilterDesc& desc, - bool need_local_merge, int node_id, - doris::IRuntimeFilter** consumer_filter) { +Status RuntimeState::register_consumer_runtime_filter( + const doris::TRuntimeFilterDesc& desc, bool need_local_merge, int node_id, + std::shared_ptr* consumer_filter) { if (desc.has_remote_targets || need_local_merge) { return global_runtime_filter_mgr()->register_consumer_filter(desc, query_options(), node_id, consumer_filter, false, true); diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h index ec812fffed8066..e3f8078156fc7e 100644 --- a/be/src/runtime/runtime_state.h +++ b/be/src/runtime/runtime_state.h @@ -561,12 +561,12 @@ class RuntimeState { Status register_producer_runtime_filter(const doris::TRuntimeFilterDesc& desc, bool need_local_merge, - doris::IRuntimeFilter** producer_filter, + std::shared_ptr* producer_filter, bool build_bf_exactly); Status register_consumer_runtime_filter(const doris::TRuntimeFilterDesc& desc, bool need_local_merge, int node_id, - doris::IRuntimeFilter** producer_filter); + std::shared_ptr* producer_filter); bool is_nereids() const; bool enable_join_spill() const { diff --git a/be/test/exprs/runtime_filter_test.cpp b/be/test/exprs/runtime_filter_test.cpp index 36d7cd885dd50c..cfcbaae4a4e6aa 100644 --- a/be/test/exprs/runtime_filter_test.cpp +++ b/be/test/exprs/runtime_filter_test.cpp @@ -48,8 +48,10 @@ class RuntimeFilterTest : public testing::Test { // std::unique_ptr _runtime_filter; }; -IRuntimeFilter* create_runtime_filter(TRuntimeFilterType::type type, TQueryOptions* options, - RuntimeState* _runtime_stat, ObjectPool* _obj_pool) { +std::shared_ptr create_runtime_filter(TRuntimeFilterType::type type, + TQueryOptions* options, + RuntimeState* _runtime_stat, + ObjectPool* _obj_pool) { TRuntimeFilterDesc desc; desc.__set_filter_id(0); desc.__set_expr_order(0); @@ -96,10 +98,10 @@ IRuntimeFilter* create_runtime_filter(TRuntimeFilterType::type type, TQueryOptio desc.__set_planId_to_target_expr(planid_to_target_expr); } - IRuntimeFilter* runtime_filter = nullptr; - Status status = IRuntimeFilter::create(RuntimeFilterParamsContext::create(_runtime_stat), - _obj_pool, &desc, options, RuntimeFilterRole::PRODUCER, - -1, &runtime_filter); + std::shared_ptr runtime_filter; + Status status = + IRuntimeFilter::create(RuntimeFilterParamsContext::create(_runtime_stat), &desc, + options, RuntimeFilterRole::PRODUCER, -1, &runtime_filter); EXPECT_TRUE(status.ok()) << status.to_string();