From 87ea5244a60c215d5a737e99301ba292a54dddba Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Thu, 22 Feb 2024 10:01:56 +0800 Subject: [PATCH] [refine](pipelinex) get sink local state does not require an id. #31195 --- be/src/pipeline/exec/hashjoin_build_sink.h | 2 +- be/src/pipeline/pipeline_x/operator.cpp | 5 ++--- be/src/pipeline/pipeline_x/operator.h | 10 ++++++---- be/src/pipeline/pipeline_x/pipeline_x_task.cpp | 8 ++++---- be/src/runtime/runtime_state.cpp | 6 +++--- be/src/runtime/runtime_state.h | 4 ++-- 6 files changed, 18 insertions(+), 17 deletions(-) diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h b/be/src/pipeline/exec/hashjoin_build_sink.h index eb78d6bf515050..080097d8bc44d7 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.h +++ b/be/src/pipeline/exec/hashjoin_build_sink.h @@ -139,7 +139,7 @@ class HashJoinBuildSinkOperatorX final SourceState source_state) override; bool should_dry_run(RuntimeState* state) override { - return _is_broadcast_join && !state->get_sink_local_state(operator_id()) + return _is_broadcast_join && !state->get_sink_local_state() ->cast() ._should_build_hash_table; } diff --git a/be/src/pipeline/pipeline_x/operator.cpp b/be/src/pipeline/pipeline_x/operator.cpp index 101b02cf4064b9..19cda4d51adda9 100644 --- a/be/src/pipeline/pipeline_x/operator.cpp +++ b/be/src/pipeline/pipeline_x/operator.cpp @@ -235,7 +235,7 @@ std::string DataSinkOperatorXBase::debug_string(int indentation_level) const { } std::string DataSinkOperatorXBase::debug_string(RuntimeState* state, int indentation_level) const { - return state->get_sink_local_state(operator_id())->debug_string(indentation_level); + return state->get_sink_local_state()->debug_string(indentation_level); } Status DataSinkOperatorXBase::init(const TDataSink& tsink) { @@ -498,8 +498,7 @@ Status StreamingOperatorX::get_block(RuntimeState* state, vector template Status StatefulOperatorX::get_block(RuntimeState* state, vectorized::Block* block, SourceState& source_state) { - auto& local_state = state->get_local_state(OperatorX::operator_id()) - ->template cast(); + auto& local_state = get_local_state(state); if (need_more_input_data(state)) { local_state._child_block->clear_column_data(); RETURN_IF_ERROR(OperatorX::_child_x->get_block_after_projects( diff --git a/be/src/pipeline/pipeline_x/operator.h b/be/src/pipeline/pipeline_x/operator.h index 87101de0612658..387fddef059209 100644 --- a/be/src/pipeline/pipeline_x/operator.h +++ b/be/src/pipeline/pipeline_x/operator.h @@ -536,8 +536,8 @@ class DataSinkOperatorXBase : public OperatorBase { [[nodiscard]] bool is_source() const override { return false; } - Status close(RuntimeState* state, Status exec_status) { - auto result = state->get_sink_local_state_result(operator_id()); + static Status close(RuntimeState* state, Status exec_status) { + auto result = state->get_sink_local_state_result(); if (!result) { return result.error(); } @@ -600,7 +600,7 @@ class DataSinkOperatorX : public DataSinkOperatorXBase { using LocalState = LocalStateType; [[nodiscard]] LocalState& get_local_state(RuntimeState* state) const { - return state->get_sink_local_state(operator_id())->template cast(); + return state->get_sink_local_state()->template cast(); } }; @@ -663,8 +663,10 @@ class StatefulOperatorX : public OperatorX { : OperatorX(pool, tnode, operator_id, descs) {} virtual ~StatefulOperatorX() = default; + using OperatorX::get_local_state; + [[nodiscard]] Status get_block(RuntimeState* state, vectorized::Block* block, - SourceState& source_state) override; + SourceState& source_state) final; [[nodiscard]] virtual Status pull(RuntimeState* state, vectorized::Block* block, SourceState& source_state) const = 0; diff --git a/be/src/pipeline/pipeline_x/pipeline_x_task.cpp b/be/src/pipeline/pipeline_x/pipeline_x_task.cpp index 7579789534fdbf..d4117661a7933e 100644 --- a/be/src/pipeline/pipeline_x/pipeline_x_task.cpp +++ b/be/src/pipeline/pipeline_x/pipeline_x_task.cpp @@ -91,9 +91,9 @@ Status PipelineXTask::prepare(const TPipelineInstanceParams& local_params, const std::vector no_scan_ranges; auto scan_ranges = find_with_default(local_params.per_node_scan_ranges, _operators.front()->node_id(), no_scan_ranges); - auto* parent_profile = _state->get_sink_local_state(_sink->operator_id())->profile(); + auto* parent_profile = _state->get_sink_local_state()->profile(); query_ctx->register_query_statistics( - _state->get_sink_local_state(_sink->operator_id())->get_query_statistics_ptr()); + _state->get_sink_local_state()->get_query_statistics_ptr()); for (int op_idx = _operators.size() - 1; op_idx >= 0; op_idx--) { auto& op = _operators[op_idx]; @@ -135,7 +135,7 @@ Status PipelineXTask::_extract_dependencies() { } } { - auto* local_state = _state->get_sink_local_state(_sink->operator_id()); + auto* local_state = _state->get_sink_local_state(); auto* dep = local_state->dependency(); DCHECK(dep != nullptr); _write_dependencies = dep; @@ -206,7 +206,7 @@ Status PipelineXTask::_open() { RETURN_IF_ERROR(st); } } - RETURN_IF_ERROR(_state->get_sink_local_state(_sink->operator_id())->open(_state)); + RETURN_IF_ERROR(_state->get_sink_local_state()->open(_state)); _opened = true; return Status::OK(); } diff --git a/be/src/runtime/runtime_state.cpp b/be/src/runtime/runtime_state.cpp index 8762366fe04520..5a32b37c40fde8 100644 --- a/be/src/runtime/runtime_state.cpp +++ b/be/src/runtime/runtime_state.cpp @@ -486,13 +486,13 @@ void RuntimeState::emplace_sink_local_state( _sink_local_state = std::move(state); } -doris::pipeline::PipelineXSinkLocalStateBase* RuntimeState::get_sink_local_state(int) { +doris::pipeline::PipelineXSinkLocalStateBase* RuntimeState::get_sink_local_state() { return _sink_local_state.get(); } -Result RuntimeState::get_sink_local_state_result(int id) { +Result RuntimeState::get_sink_local_state_result() { if (!_sink_local_state) { - return ResultError(Status::InternalError("_op_id_to_sink_local_state id:{} is null", id)); + return ResultError(Status::InternalError("_op_id_to_sink_local_state not exist")); } return _sink_local_state.get(); } diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h index 03b518e5c344a1..bd54ca98e668ef 100644 --- a/be/src/runtime/runtime_state.h +++ b/be/src/runtime/runtime_state.h @@ -552,9 +552,9 @@ class RuntimeState { void emplace_sink_local_state(int id, std::unique_ptr state); - SinkLocalState* get_sink_local_state(int id); + SinkLocalState* get_sink_local_state(); - Result get_sink_local_state_result(int id); + Result get_sink_local_state_result(); void resize_op_id_to_local_state(int operator_size);