diff --git a/dbms/src/DataStreams/TiRemoteBlockInputStream.h b/dbms/src/DataStreams/TiRemoteBlockInputStream.h index 124f08d65c4..fc4b44b4ac8 100644 --- a/dbms/src/DataStreams/TiRemoteBlockInputStream.h +++ b/dbms/src/DataStreams/TiRemoteBlockInputStream.h @@ -171,7 +171,7 @@ class TiRemoteBlockInputStream : public IProfilingBlockInputStream protected: void readSuffixImpl() override { - LOG_DEBUG(log, "finish read {} rows from remote", total_rows); + LOG_INFO(log, "finish read {} rows from remote", total_rows); } void appendInfo(FmtBuffer & buffer) const override diff --git a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp index 21d4d649ffe..580870386e2 100644 --- a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp +++ b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,9 @@ bool ExchangeReceiverBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int3 tipb_executor->set_fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count); tipb::ExchangeReceiver * exchange_receiver = tipb_executor->mutable_exchange_receiver(); + if (exchange_sender) + exchange_receiver->set_tp(exchange_sender->getType()); + for (auto & field : output_schema) { auto tipb_type = TiDB::columnInfoToFieldType(field.second); @@ -61,9 +65,16 @@ bool ExchangeReceiverBinder::toTiPBExecutor(tipb::Executor * tipb_executor, int3 } -ExecutorBinderPtr compileExchangeReceiver(size_t & executor_index, DAGSchema schema, uint64_t fine_grained_shuffle_stream_count) +void ExchangeReceiverBinder::toMPPSubPlan(size_t & executor_index, const DAGProperties & properties, std::unordered_map, std::shared_ptr>> & exchange_map) +{ + RUNTIME_CHECK_MSG(exchange_sender, "exchange_sender must not be nullptr in toMPPSubPlan"); + exchange_sender->toMPPSubPlan(executor_index, properties, exchange_map); + exchange_map[name] = std::make_pair(shared_from_this(), exchange_sender); +} + +ExecutorBinderPtr compileExchangeReceiver(size_t & executor_index, DAGSchema schema, uint64_t fine_grained_shuffle_stream_count, const std::shared_ptr & exchange_sender) { - ExecutorBinderPtr exchange_receiver = std::make_shared(executor_index, schema, fine_grained_shuffle_stream_count); + ExecutorBinderPtr exchange_receiver = std::make_shared(executor_index, schema, fine_grained_shuffle_stream_count, exchange_sender); return exchange_receiver; } } // namespace DB::mock diff --git a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h index c2327c87861..ec0cf0d956e 100644 --- a/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h +++ b/dbms/src/Debug/MockExecutor/ExchangeReceiverBinder.h @@ -20,21 +20,30 @@ namespace DB::mock { class ExchangeReceiverBinder : public ExecutorBinder + , public std::enable_shared_from_this { public: - ExchangeReceiverBinder(size_t & index, const DAGSchema & output, uint64_t fine_grained_shuffle_stream_count_ = 0) + ExchangeReceiverBinder( + size_t & index, + const DAGSchema & output, + uint64_t fine_grained_shuffle_stream_count_ = 0, + const std::shared_ptr & exchange_sender_ = nullptr) : ExecutorBinder(index, "exchange_receiver_" + std::to_string(index), output) , fine_grained_shuffle_stream_count(fine_grained_shuffle_stream_count_) + , exchange_sender(exchange_sender_) {} bool toTiPBExecutor(tipb::Executor * tipb_executor, int32_t collator_id, const MPPInfo & mpp_info, const Context &) override; void columnPrune(std::unordered_set &) override {} + void toMPPSubPlan(size_t & executor_index, const DAGProperties &, std::unordered_map, std::shared_ptr>> & exchange_map) override; + private: TaskMetas task_metas; uint64_t fine_grained_shuffle_stream_count; + std::shared_ptr exchange_sender; }; -ExecutorBinderPtr compileExchangeReceiver(size_t & executor_index, DAGSchema schema, uint64_t fine_grained_shuffle_stream_count); +ExecutorBinderPtr compileExchangeReceiver(size_t & executor_index, DAGSchema schema, uint64_t fine_grained_shuffle_stream_count, const std::shared_ptr & exchange_sender); } // namespace DB::mock diff --git a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp index d0716daddbc..f3c73634740 100644 --- a/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp +++ b/dbms/src/Flash/Coprocessor/DAGQueryBlockInterpreter.cpp @@ -757,7 +757,6 @@ void DAGQueryBlockInterpreter::executeExpand(DAGPipeline & pipeline, const Expre void DAGQueryBlockInterpreter::handleExchangeSender(DAGPipeline & pipeline) { - RUNTIME_ASSERT(dagContext().isMPPTask() && dagContext().tunnel_set != nullptr, log, "exchange_sender only run in MPP"); /// exchange sender should be at the top of operators const auto & exchange_sender = query_block.exchange_sender->exchange_sender(); std::vector partition_col_ids = ExchangeSenderInterpreterHelper::genPartitionColIds(exchange_sender); @@ -776,7 +775,6 @@ void DAGQueryBlockInterpreter::handleExchangeSender(DAGPipeline & pipeline) pipeline.transform([&](auto & stream) { // construct writer std::unique_ptr response_writer = newMPPExchangeWriter( - dagContext().tunnel_set, partition_col_ids, partition_col_collators, exchange_sender.tp(), @@ -787,7 +785,8 @@ void DAGQueryBlockInterpreter::handleExchangeSender(DAGPipeline & pipeline) stream_count, batch_size, exchange_sender.compression(), - context.getSettingsRef().batch_send_min_limit_compression); + context.getSettingsRef().batch_send_min_limit_compression, + log->identifier()); stream = std::make_shared(stream, std::move(response_writer), log->identifier()); stream->setExtraInfo(extra_info); }); diff --git a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h index 3ca0580482b..f5356ae1661 100644 --- a/dbms/src/Flash/Coprocessor/DAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/DAGResponseWriter.h @@ -30,6 +30,14 @@ class DAGResponseWriter /// prepared with sample block virtual void prepare(const Block &){}; virtual void write(const Block & block) = 0; + + // For async writer, `isReadyForWrite` need to be called before calling `write`. + // ``` + // while (!isReadyForWrite()) {} + // write(block); + // ``` + virtual bool isReadyForWrite() const { throw Exception("Unsupport"); } + /// flush cached blocks for batch writer virtual void flush() = 0; virtual ~DAGResponseWriter() = default; diff --git a/dbms/src/Flash/Coprocessor/StreamWriter.h b/dbms/src/Flash/Coprocessor/StreamWriter.h index e52b7f06991..8cf498694e5 100644 --- a/dbms/src/Flash/Coprocessor/StreamWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamWriter.h @@ -55,6 +55,7 @@ struct StreamWriter if (!writer->Write(resp)) throw Exception("Failed to write resp"); } + bool isReadyForWrite() const { throw Exception("Unsupport async write"); } }; using StreamWriterPtr = std::shared_ptr; diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp index feb50a17ada..abea17b9f0e 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.cpp @@ -20,7 +20,7 @@ #include #include #include -#include +#include namespace DB { @@ -68,6 +68,12 @@ void StreamingDAGResponseWriter::flush() encodeThenWriteBlocks(); } +template +bool StreamingDAGResponseWriter::isReadyForWrite() const +{ + return writer->isReadyForWrite(); +} + template void StreamingDAGResponseWriter::write(const Block & block) { @@ -141,5 +147,6 @@ void StreamingDAGResponseWriter::encodeThenWriteBlocks() } template class StreamingDAGResponseWriter; -template class StreamingDAGResponseWriter; +template class StreamingDAGResponseWriter; +template class StreamingDAGResponseWriter; } // namespace DB diff --git a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h index 16e2014cbd5..6d2780f4d1e 100644 --- a/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h +++ b/dbms/src/Flash/Coprocessor/StreamingDAGResponseWriter.h @@ -37,6 +37,7 @@ class StreamingDAGResponseWriter : public DAGResponseWriter Int64 batch_send_min_limit_, DAGContext & dag_context_); void write(const Block & block) override; + bool isReadyForWrite() const override; void flush() override; private: diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp index 1b429184996..4cb03559d79 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_streaming_writer.cpp @@ -101,6 +101,7 @@ struct MockStreamWriter {} void write(tipb::SelectResponse & response) { checker(response); } + bool isReadyForWrite() const { throw Exception("Unsupport async write"); } private: MockStreamWriterChecker checker; diff --git a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp index d6d1097d720..6e5dc744628 100644 --- a/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp +++ b/dbms/src/Flash/Coprocessor/tests/gtest_ti_remote_block_inputstream.cpp @@ -118,16 +118,12 @@ struct MockWriter tracked_packet->serializeByResponse(response); queue->push(tracked_packet); } - void sendExecutionSummary(const tipb::SelectResponse & response) - { - tipb::SelectResponse tmp = response; - write(tmp); - } uint16_t getPartitionNum() const { return 1; } bool isLocal(size_t index) const { return index == 0; } + bool isReadyForWrite() const { throw Exception("Unsupport async write"); } std::vector result_field_types; @@ -371,7 +367,8 @@ class TestTiRemoteBlockInputStream : public testing::Test // 3. send execution summary writer->add_summary = true; ExecutionSummaryCollector summary_collector(*dag_context_ptr); - writer->sendExecutionSummary(summary_collector.genExecutionSummaryResponse()); + auto summary_response = summary_collector.genExecutionSummaryResponse(); + writer->write(summary_response); } void prepareQueueV2( diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp index 11dae1896d8..01df6814dec 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.cpp @@ -16,7 +16,7 @@ #include #include #include -#include +#include namespace DB { @@ -40,6 +40,12 @@ void BroadcastOrPassThroughWriter::flush() writeBlocks(); } +template +bool BroadcastOrPassThroughWriter::isReadyForWrite() const +{ + return writer->isReadyForWrite(); +} + template void BroadcastOrPassThroughWriter::write(const Block & block) { @@ -68,6 +74,6 @@ void BroadcastOrPassThroughWriter::writeBlocks() rows_in_blocks = 0; } -template class BroadcastOrPassThroughWriter; - +template class BroadcastOrPassThroughWriter; +template class BroadcastOrPassThroughWriter; } // namespace DB diff --git a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h index a272ec7f1a4..296fda38ba9 100644 --- a/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h +++ b/dbms/src/Flash/Mpp/BroadcastOrPassThroughWriter.h @@ -32,6 +32,7 @@ class BroadcastOrPassThroughWriter : public DAGResponseWriter Int64 batch_send_min_limit_, DAGContext & dag_context_); void write(const Block & block) override; + bool isReadyForWrite() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp index 43fd7b721d4..01246e9334a 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.cpp +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.cpp @@ -707,34 +707,95 @@ DecodeDetail ExchangeReceiverBase::decodeChunks( } template -ExchangeReceiverResult ExchangeReceiverBase::nextResult( - std::queue & block_queue, - const Block & header, +ReceiveResult ExchangeReceiverBase::receive(size_t stream_id) +{ + return receive( + stream_id, + [&](size_t stream_id, std::shared_ptr & recv_msg) { + return msg_channels[stream_id]->pop(recv_msg); + }); +} + +template +ReceiveResult ExchangeReceiverBase::nonBlockingReceive(size_t stream_id) +{ + return receive( + stream_id, + [&](size_t stream_id, std::shared_ptr & recv_msg) { + return msg_channels[stream_id]->tryPop(recv_msg); + }); +} + +template +ReceiveResult ExchangeReceiverBase::receive( size_t stream_id, - std::unique_ptr & decoder_ptr) + std::function &)> recv_func) { if (unlikely(stream_id >= msg_channels.size())) { - LOG_ERROR(exc_log, "stream_id out of range, stream_id: {}, total_stream_count: {}", stream_id, msg_channels.size()); - return ExchangeReceiverResult::newError(0, "", "stream_id out of range"); + auto err_msg = fmt::format("stream_id out of range, stream_id: {}, total_channel_count: {}", stream_id, msg_channels.size()); + LOG_ERROR(exc_log, err_msg); + throw Exception(err_msg); } std::shared_ptr recv_msg; - if (msg_channels[stream_id]->pop(recv_msg) != MPMCQueueResult::OK) + switch (recv_func(stream_id, recv_msg)) { - return handleUnnormalChannel(block_queue, decoder_ptr); + case MPMCQueueResult::OK: + assert(recv_msg); + return {ReceiveStatus::ok, std::move(recv_msg)}; + case MPMCQueueResult::EMPTY: + return {ReceiveStatus::empty, nullptr}; + default: + return {ReceiveStatus::eof, nullptr}; } - else +} + +template +ExchangeReceiverResult ExchangeReceiverBase::toExchangeReceiveResult( + ReceiveResult & recv_result, + std::queue & block_queue, + const Block & header, + std::unique_ptr & decoder_ptr) +{ + switch (recv_result.recv_status) + { + case ReceiveStatus::ok: { - assert(recv_msg != nullptr); - if (unlikely(recv_msg->error_ptr != nullptr)) - return ExchangeReceiverResult::newError(recv_msg->source_index, recv_msg->req_info, recv_msg->error_ptr->msg()); + assert(recv_result.recv_msg != nullptr); + if (unlikely(recv_result.recv_msg->error_ptr != nullptr)) + return ExchangeReceiverResult::newError( + recv_result.recv_msg->source_index, + recv_result.recv_msg->req_info, + recv_result.recv_msg->error_ptr->msg()); - ExchangeReceiverMetric::subDataSizeMetric(data_size_in_queue, recv_msg->packet->getPacket().ByteSizeLong()); - return toDecodeResult(block_queue, header, recv_msg, decoder_ptr); + ExchangeReceiverMetric::subDataSizeMetric( + data_size_in_queue, + recv_result.recv_msg->packet->getPacket().ByteSizeLong()); + return toDecodeResult(block_queue, header, recv_result.recv_msg, decoder_ptr); + } + case ReceiveStatus::eof: + return handleUnnormalChannel(block_queue, decoder_ptr); + case ReceiveStatus::empty: + throw Exception("Unexpected recv status: empty"); } } +template +ExchangeReceiverResult ExchangeReceiverBase::nextResult( + std::queue & block_queue, + const Block & header, + size_t stream_id, + std::unique_ptr & decoder_ptr) +{ + auto recv_res = receive(stream_id); + return toExchangeReceiveResult( + recv_res, + block_queue, + header, + decoder_ptr); +} + template ExchangeReceiverResult ExchangeReceiverBase::handleUnnormalChannel( std::queue & block_queue, diff --git a/dbms/src/Flash/Mpp/ExchangeReceiver.h b/dbms/src/Flash/Mpp/ExchangeReceiver.h index 8cd72e1c4cb..c395b5cafcb 100644 --- a/dbms/src/Flash/Mpp/ExchangeReceiver.h +++ b/dbms/src/Flash/Mpp/ExchangeReceiver.h @@ -82,6 +82,19 @@ enum class ExchangeReceiverState CLOSED, }; +enum class ReceiveStatus +{ + empty, + ok, + eof, +}; + +struct ReceiveResult +{ + ReceiveStatus recv_status; + std::shared_ptr recv_msg; +}; + template class ExchangeReceiverBase { @@ -105,6 +118,15 @@ class ExchangeReceiverBase void cancel(); void close(); + ReceiveResult receive(size_t stream_id); + ReceiveResult nonBlockingReceive(size_t stream_id); + + ExchangeReceiverResult toExchangeReceiveResult( + ReceiveResult & recv_result, + std::queue & block_queue, + const Block & header, + std::unique_ptr & decoder_ptr); + ExchangeReceiverResult nextResult( std::queue & block_queue, const Block & header, @@ -159,6 +181,10 @@ class ExchangeReceiverBase const std::shared_ptr & recv_msg, std::unique_ptr & decoder_ptr); + ReceiveResult receive( + size_t stream_id, + std::function &)> recv_func); + private: void prepareMsgChannels(); void addLocalConnectionNum(); diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp index f24d59942fd..70e6ba5cfbc 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.cpp @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace DB { @@ -95,6 +95,12 @@ void FineGrainedShuffleWriter::flush() batchWriteFineGrainedShuffle(); } +template +bool FineGrainedShuffleWriter::isReadyForWrite() const +{ + return writer->isReadyForWrite(); +} + template void FineGrainedShuffleWriter::write(const Block & block) { @@ -192,6 +198,7 @@ void FineGrainedShuffleWriter::batchWriteFineGrainedShuffle() } } -template class FineGrainedShuffleWriter; +template class FineGrainedShuffleWriter; +template class FineGrainedShuffleWriter; } // namespace DB diff --git a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h index 44b26dfc2ae..cc94f6ebf58 100644 --- a/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h +++ b/dbms/src/Flash/Mpp/FineGrainedShuffleWriter.h @@ -45,6 +45,7 @@ class FineGrainedShuffleWriter : public DAGResponseWriter tipb::CompressionMode compression_mode_); void prepare(const Block & sample_block) override; void write(const Block & block) override; + bool isReadyForWrite() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/GRPCSendQueue.h b/dbms/src/Flash/Mpp/GRPCSendQueue.h index 70c63b03fbd..8ba3d269541 100644 --- a/dbms/src/Flash/Mpp/GRPCSendQueue.h +++ b/dbms/src/Flash/Mpp/GRPCSendQueue.h @@ -120,6 +120,16 @@ class GRPCSendQueue return ret; } + bool nonBlockingPush(T && data) + { + auto ret = send_queue.nonBlockingPush(std::move(data)) == MPMCQueueResult::OK; + if (ret) + { + kickCompletionQueue(); + } + return ret; + } + /// Cancel the send queue, and set the cancel reason bool cancelWith(const String & reason) { @@ -193,6 +203,11 @@ class GRPCSendQueue return ret; } + bool isFull() const + { + return send_queue.isFull(); + } + private: friend class tests::TestGRPCSendQueue; diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp index a2f5c9c1c2d..9646e6a1ad4 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.cpp +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.cpp @@ -18,7 +18,7 @@ #include #include #include -#include +#include namespace DB { @@ -91,6 +91,12 @@ void HashPartitionWriter::flush() } } +template +bool HashPartitionWriter::isReadyForWrite() const +{ + return writer->isReadyForWrite(); +} + template void HashPartitionWriter::writeImplV1(const Block & block) { @@ -236,6 +242,7 @@ void HashPartitionWriter::writePartitionBlocks(std::vector; +template class HashPartitionWriter; +template class HashPartitionWriter; } // namespace DB diff --git a/dbms/src/Flash/Mpp/HashPartitionWriter.h b/dbms/src/Flash/Mpp/HashPartitionWriter.h index 096e6df465c..12a22705437 100644 --- a/dbms/src/Flash/Mpp/HashPartitionWriter.h +++ b/dbms/src/Flash/Mpp/HashPartitionWriter.h @@ -38,6 +38,7 @@ class HashPartitionWriter : public DAGResponseWriter MPPDataPacketVersion data_codec_version_, tipb::CompressionMode compression_mode_); void write(const Block & block) override; + bool isReadyForWrite() const override; void flush() override; private: diff --git a/dbms/src/Flash/Mpp/LocalRequestHandler.h b/dbms/src/Flash/Mpp/LocalRequestHandler.h index f4c2364c33f..3bc13a7f5ba 100644 --- a/dbms/src/Flash/Mpp/LocalRequestHandler.h +++ b/dbms/src/Flash/Mpp/LocalRequestHandler.h @@ -33,10 +33,15 @@ struct LocalRequestHandler , channel_writer(std::move(channel_writer_)) {} - template + template bool write(size_t source_index, const TrackedMppDataPacketPtr & tracked_packet) { - return channel_writer.write(source_index, tracked_packet); + return channel_writer.write(source_index, tracked_packet); + } + + bool isReadyForWrite() const + { + return channel_writer.isReadyForWrite(); } void writeDone(bool meet_error, const String & local_err_msg) const diff --git a/dbms/src/Flash/Mpp/MPPTask.cpp b/dbms/src/Flash/Mpp/MPPTask.cpp index 4f97a94afd7..290e25f093d 100644 --- a/dbms/src/Flash/Mpp/MPPTask.cpp +++ b/dbms/src/Flash/Mpp/MPPTask.cpp @@ -154,7 +154,7 @@ void MPPTask::run() void MPPTask::registerTunnels(const mpp::DispatchTaskRequest & task_request) { - auto tunnel_set_local = std::make_shared(*dag_context, log->identifier()); + auto tunnel_set_local = std::make_shared(log->identifier()); std::chrono::seconds timeout(task_request.timeout()); const auto & exchange_sender = dag_req.root_executor().exchange_sender(); diff --git a/dbms/src/Flash/Mpp/MPPTunnel.cpp b/dbms/src/Flash/Mpp/MPPTunnel.cpp index 4978e6f4f27..2f5db9b7810 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnel.cpp @@ -20,6 +20,7 @@ #include #include +#include namespace DB { @@ -170,6 +171,19 @@ void MPPTunnel::write(TrackedMppDataPacketPtr && data) throw Exception(fmt::format("write to tunnel {} which is already closed, {}", tunnel_id, tunnel_sender->isConsumerFinished() ? tunnel_sender->getConsumerFinishMsg() : "")); } +void MPPTunnel::nonBlockingWrite(TrackedMppDataPacketPtr && data) +{ + LOG_TRACE(log, "start non blocking writing"); + auto pushed_data_size = data->getPacket().ByteSizeLong(); + if (tunnel_sender->nonBlockingPush(std::move(data))) + { + updateMetric(data_size_in_queue, pushed_data_size, mode); + updateConnProfileInfo(pushed_data_size); + return; + } + throw Exception(fmt::format("write to tunnel {} which is already closed, {}", tunnel_id, tunnel_sender->isConsumerFinished() ? tunnel_sender->getConsumerFinishMsg() : "")); +} + /// done normally and being called exactly once after writing all packets void MPPTunnel::writeDone() { @@ -314,23 +328,39 @@ void MPPTunnel::waitUntilConnectedOrFinished(std::unique_lock & lk) throw Exception(fmt::format("MPPTunnel {} can not be connected because MPPTask is cancelled", tunnel_id)); } -StringRef MPPTunnel::statusToString() +bool MPPTunnel::isReadyForWrite() const { + std::unique_lock lk(mu); switch (status) { case TunnelStatus::Unconnected: - return "Unconnected"; + { + if (timeout.count() > 0) + { + fiu_do_on(FailPoints::random_tunnel_wait_timeout_failpoint, throw Exception(tunnel_id + " is timeout");); + if (unlikely(!timeout_stopwatch)) + timeout_stopwatch.emplace(CLOCK_MONOTONIC_COARSE); + if (timeout_stopwatch->elapsedSeconds() > timeout.count()) + throw Exception(tunnel_id + " is timeout"); + } + return false; + } case TunnelStatus::Connected: - return "Connected"; - case TunnelStatus::WaitingForSenderFinish: - return "WaitingForSenderFinish"; - case TunnelStatus::Finished: - return "Finished"; + RUNTIME_CHECK_MSG(tunnel_sender != nullptr, "write to tunnel {} which is already closed.", tunnel_id); + return tunnel_sender->isReadyForWrite(); default: - RUNTIME_ASSERT(false, log, "Unknown TaskStatus {}", static_cast(status)); + // Returns true directly for TunnelStatus::WaitingForSenderFinish and TunnelStatus::Finished, + // and then handled by `nonBlockingWrite`. + RUNTIME_CHECK_MSG(tunnel_sender != nullptr, "write to tunnel {} which is already closed.", tunnel_id); + return true; } } +std::string_view MPPTunnel::statusToString() +{ + return magic_enum::enum_name(status); +} + void TunnelSender::consumerFinish(const String & msg) { LOG_TRACE(log, "calling consumer Finish"); diff --git a/dbms/src/Flash/Mpp/MPPTunnel.h b/dbms/src/Flash/Mpp/MPPTunnel.h index c857506901b..618bec38387 100644 --- a/dbms/src/Flash/Mpp/MPPTunnel.h +++ b/dbms/src/Flash/Mpp/MPPTunnel.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -103,11 +104,14 @@ class TunnelSender : private boost::noncopyable } virtual bool push(TrackedMppDataPacketPtr &&) = 0; + virtual bool nonBlockingPush(TrackedMppDataPacketPtr &&) = 0; virtual void cancelWith(const String &) = 0; virtual bool finish() = 0; + virtual bool isReadyForWrite() const = 0; + void consumerFinish(const String & err_msg); String getConsumerFinishMsg() { @@ -184,6 +188,11 @@ class SyncTunnelSender : public TunnelSender return send_queue.push(std::move(data)) == MPMCQueueResult::OK; } + bool nonBlockingPush(TrackedMppDataPacketPtr && data) override + { + return send_queue.nonBlockingPush(std::move(data)) == MPMCQueueResult::OK; + } + void cancelWith(const String & reason) override { send_queue.cancelWith(reason); @@ -194,6 +203,11 @@ class SyncTunnelSender : public TunnelSender return send_queue.finish(); } + bool isReadyForWrite() const override + { + return !send_queue.isFull(); + } + private: friend class tests::TestMPPTunnel; void sendJob(PacketWriter * writer); @@ -221,11 +235,21 @@ class AsyncTunnelSender : public TunnelSender return queue.push(std::move(data)); } + bool nonBlockingPush(TrackedMppDataPacketPtr && data) override + { + return queue.nonBlockingPush(std::move(data)); + } + bool finish() override { return queue.finish(); } + bool isReadyForWrite() const override + { + return !queue.isFull(); + } + void cancelWith(const String & reason) override { queue.cancelWith(reason); @@ -289,7 +313,20 @@ class LocalTunnelSenderV2 : public TunnelSender // is responsible for deleting receiver_mem_tracker must be destroyed after these local tunnels. data->switchMemTracker(local_request_handler.recv_mem_tracker); - return local_request_handler.write(source_index, data); + return local_request_handler.write(source_index, data); + } + + bool nonBlockingPush(TrackedMppDataPacketPtr && data) override + { + if (unlikely(checkPacketErr(data))) + return false; + + // receiver_mem_tracker pointer will always be valid because ExchangeReceiverBase won't be destructed + // before all local tunnels are destructed so that the MPPTask which contains ExchangeReceiverBase and + // is responsible for deleting receiver_mem_tracker must be destroyed after these local tunnels. + data->switchMemTracker(local_request_handler.recv_mem_tracker); + + return local_request_handler.write(source_index, data); } void cancelWith(const String & reason) override @@ -303,6 +340,11 @@ class LocalTunnelSenderV2 : public TunnelSender return true; } + bool isReadyForWrite() const override + { + return local_request_handler.isReadyForWrite(); + } + private: friend class tests::TestMPPTunnel; @@ -352,6 +394,11 @@ class LocalTunnelSenderV1 : public TunnelSender return send_queue.push(std::move(data)) == MPMCQueueResult::OK; } + bool nonBlockingPush(TrackedMppDataPacketPtr && data) override + { + return send_queue.nonBlockingPush(std::move(data)) == MPMCQueueResult::OK; + } + void cancelWith(const String & reason) override { send_queue.cancelWith(reason); @@ -362,6 +409,11 @@ class LocalTunnelSenderV1 : public TunnelSender return send_queue.finish(); } + bool isReadyForWrite() const override + { + return !send_queue.isFull(); + } + private: bool cancel_reason_sent = false; ConcurrentIOQueue send_queue; @@ -428,6 +480,15 @@ class MPPTunnel : private boost::noncopyable // write a single packet to the tunnel's send queue, it will block if tunnel is not ready. void write(TrackedMppDataPacketPtr && data); + // nonBlockingWrite write a single packet to the tunnel's send queue without blocking, + // and need to call isReadForWrite first. + // ``` + // while (!isReadyForWrite()) {} + // nonBlockingWrite(std::move(data)); + // ``` + void nonBlockingWrite(TrackedMppDataPacketPtr && data); + bool isReadyForWrite() const; + // finish the writing, and wait until the sender finishes. void writeDone(); @@ -475,7 +536,7 @@ class MPPTunnel : private boost::noncopyable Finished // Final state, no more work to do }; - StringRef statusToString(); + std::string_view statusToString(); void waitUntilConnectedOrFinished(std::unique_lock & lk); @@ -492,12 +553,14 @@ class MPPTunnel : private boost::noncopyable connection_profile_info.packets += 1; } - std::mutex mu; +private: + mutable std::mutex mu; std::condition_variable cv_for_status_changed; TunnelStatus status; std::chrono::seconds timeout; + mutable std::optional timeout_stopwatch; // tunnel id is in the format like "tunnel[sender]+[receiver]" String tunnel_id; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp index a308a9717a3..890b0e41b5c 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.cpp +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.cpp @@ -13,11 +13,8 @@ // limitations under the License. #include -#include -#include #include #include -#include #include #include #include @@ -26,27 +23,14 @@ namespace DB { namespace { -void checkPacketSize(size_t size) -{ - static constexpr size_t max_packet_size = 1u << 31; - RUNTIME_CHECK(size < max_packet_size, fmt::format("Packet is too large to send, size : {}", size)); -} - TrackedMppDataPacketPtr serializePacket(const tipb::SelectResponse & response) { auto tracked_packet = std::make_shared(MPPDataPacketV0); tracked_packet->serializeByResponse(response); - checkPacketSize(tracked_packet->getPacket().ByteSizeLong()); return tracked_packet; } } // namespace -template -MPPTunnelSetBase::MPPTunnelSetBase(DAGContext & dag_context, const String & req_id) - : log(Logger::get(req_id)) - , result_field_types(dag_context.result_field_types) -{} - template void MPPTunnelSetBase::sendExecutionSummary(const tipb::SelectResponse & response) { @@ -56,196 +40,51 @@ void MPPTunnelSetBase::sendExecutionSummary(const tipb::SelectResponse & } template -void MPPTunnelSetBase::write(tipb::SelectResponse & response) -{ - // for root mpp task, only one tunnel will connect to tidb/tispark. - RUNTIME_CHECK(1 == tunnels.size()); - tunnels.back()->write(serializePacket(response)); -} - -static inline void updatePartitionWriterMetrics(size_t packet_bytes, bool is_local) -{ - // statistic - GET_METRIC(tiflash_exchange_data_bytes, type_hash_original).Increment(packet_bytes); - // compression method is always NONE - if (is_local) - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_local).Increment(packet_bytes); - else - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_remote).Increment(packet_bytes); -} - -static inline void updatePartitionWriterMetrics(CompressionMethod method, size_t original_size, size_t sz, bool is_local) +void MPPTunnelSetBase::write(TrackedMppDataPacketPtr && data, size_t index) { - // statistic - GET_METRIC(tiflash_exchange_data_bytes, type_hash_original).Increment(original_size); - - switch (method) - { - case CompressionMethod::NONE: - { - if (is_local) - { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_local).Increment(sz); - } - else - { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_remote).Increment(sz); - } - break; - } - case CompressionMethod::LZ4: - { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_lz4_compression).Increment(sz); - break; - } - case CompressionMethod::ZSTD: - { - GET_METRIC(tiflash_exchange_data_bytes, type_hash_zstd_compression).Increment(sz); - break; - } - default: - break; - } + assert(index < tunnels.size()); + tunnels[index]->write(std::move(data)); } template -void MPPTunnelSetBase::broadcastOrPassThroughWrite(Blocks & blocks) +void MPPTunnelSetBase::nonBlockingWrite(TrackedMppDataPacketPtr && data, size_t index) { - RUNTIME_CHECK(!tunnels.empty()); - auto && tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); - if (!tracked_packet) - return; - - auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); - checkPacketSize(packet_bytes); - // TODO avoid copy packet for broadcast. - for (size_t i = 1; i < tunnels.size(); ++i) - tunnels[i]->write(tracked_packet->copy()); - tunnels[0]->write(std::move(tracked_packet)); - { - // statistic - size_t data_bytes = 0; - size_t local_data_bytes = 0; - { - auto tunnel_cnt = getPartitionNum(); - size_t local_tunnel_cnt = 0; - for (size_t i = 0; i < tunnel_cnt; ++i) - { - local_tunnel_cnt += isLocal(i); - } - data_bytes = packet_bytes * tunnel_cnt; - local_data_bytes = packet_bytes * local_tunnel_cnt; - } - GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_original).Increment(data_bytes); - GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_none_compression_local).Increment(local_data_bytes); - GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_none_compression_remote).Increment(data_bytes - local_data_bytes); - } + assert(index < tunnels.size()); + tunnels[index]->nonBlockingWrite(std::move(data)); } template -void MPPTunnelSetBase::partitionWrite(Blocks & blocks, int16_t partition_id) +void MPPTunnelSetBase::write(tipb::SelectResponse & response, size_t index) { - auto && tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); - if (!tracked_packet) - return; - auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); - checkPacketSize(packet_bytes); - tunnels[partition_id]->write(std::move(tracked_packet)); - updatePartitionWriterMetrics(packet_bytes, isLocal(partition_id)); + assert(index < tunnels.size()); + tunnels[index]->write(serializePacket(response)); } template -void MPPTunnelSetBase::partitionWrite( - const Block & header, - std::vector && part_columns, - int16_t partition_id, - MPPDataPacketVersion version, - CompressionMethod compression_method) +void MPPTunnelSetBase::nonBlockingWrite(tipb::SelectResponse & response, size_t index) { - assert(version > MPPDataPacketV0); - - bool is_local = isLocal(partition_id); - compression_method = is_local ? CompressionMethod::NONE : compression_method; - - size_t original_size = 0; - auto tracked_packet = MPPTunnelSetHelper::ToPacket(header, std::move(part_columns), version, compression_method, original_size); - if (!tracked_packet) - return; - - auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); - checkPacketSize(packet_bytes); - tunnels[partition_id]->write(std::move(tracked_packet)); - updatePartitionWriterMetrics(compression_method, original_size, packet_bytes, is_local); + assert(index < tunnels.size()); + tunnels[index]->nonBlockingWrite(serializePacket(response)); } template -void MPPTunnelSetBase::fineGrainedShuffleWrite( - const Block & header, - std::vector & scattered, - size_t bucket_idx, - UInt64 fine_grained_shuffle_stream_count, - size_t num_columns, - int16_t partition_id, - MPPDataPacketVersion version, - CompressionMethod compression_method) +bool MPPTunnelSetBase::isReadyForWrite() const { - if (version == MPPDataPacketV0) - return fineGrainedShuffleWrite(header, scattered, bucket_idx, fine_grained_shuffle_stream_count, num_columns, partition_id); - - bool is_local = isLocal(partition_id); - compression_method = is_local ? CompressionMethod::NONE : compression_method; - - size_t original_size = 0; - auto tracked_packet = MPPTunnelSetHelper::ToFineGrainedPacket( - header, - scattered, - bucket_idx, - fine_grained_shuffle_stream_count, - num_columns, - version, - compression_method, - original_size); - - if unlikely (tracked_packet->getPacket().chunks_size() <= 0) - return; - - auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); - checkPacketSize(packet_bytes); - tunnels[partition_id]->write(std::move(tracked_packet)); - updatePartitionWriterMetrics(compression_method, original_size, packet_bytes, is_local); -} - -template -void MPPTunnelSetBase::fineGrainedShuffleWrite( - const Block & header, - std::vector & scattered, - size_t bucket_idx, - UInt64 fine_grained_shuffle_stream_count, - size_t num_columns, - int16_t partition_id) -{ - auto tracked_packet = MPPTunnelSetHelper::ToFineGrainedPacketV0( - header, - scattered, - bucket_idx, - fine_grained_shuffle_stream_count, - num_columns, - result_field_types); - - if unlikely (tracked_packet->getPacket().chunks_size() <= 0) - return; - - auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); - checkPacketSize(packet_bytes); - tunnels[partition_id]->write(std::move(tracked_packet)); - updatePartitionWriterMetrics(packet_bytes, isLocal(partition_id)); + for (const auto & tunnel : tunnels) + { + if (!tunnel->isReadyForWrite()) + return false; + } + return true; } template void MPPTunnelSetBase::registerTunnel(const MPPTaskId & receiver_task_id, const TunnelPtr & tunnel) { - if (receiver_task_id_to_index_map.find(receiver_task_id) != receiver_task_id_to_index_map.end()) - throw Exception(fmt::format("the tunnel {} has been registered", tunnel->id())); + RUNTIME_CHECK_MSG( + receiver_task_id_to_index_map.find(receiver_task_id) == receiver_task_id_to_index_map.end(), + "the tunnel {} has been registered", + tunnel->id()); receiver_task_id_to_index_map[receiver_task_id] = tunnels.size(); tunnels.push_back(tunnel); diff --git a/dbms/src/Flash/Mpp/MPPTunnelSet.h b/dbms/src/Flash/Mpp/MPPTunnelSet.h index 4f2ff4e7b33..ab1d96b6e9c 100644 --- a/dbms/src/Flash/Mpp/MPPTunnelSet.h +++ b/dbms/src/Flash/Mpp/MPPTunnelSet.h @@ -28,39 +28,18 @@ class MPPTunnelSetBase : private boost::noncopyable { public: using TunnelPtr = std::shared_ptr; - MPPTunnelSetBase(DAGContext & dag_context, const String & req_id); - - // this is a root mpp writing. - void write(tipb::SelectResponse & response); - // this is a broadcast or pass through writing. - // data codec version V0 - void broadcastOrPassThroughWrite(Blocks & blocks); - // this is a partition writing. - // data codec version V0 - void partitionWrite(Blocks & blocks, int16_t partition_id); - // data codec version > V0 - void partitionWrite(const Block & header, std::vector && part_columns, int16_t partition_id, MPPDataPacketVersion version, CompressionMethod compression_method); - // this is a fine grained shuffle writing. - // data codec version V0 - void fineGrainedShuffleWrite( - const Block & header, - std::vector & scattered, - size_t bucket_idx, - UInt64 fine_grained_shuffle_stream_count, - size_t num_columns, - int16_t partition_id); - void fineGrainedShuffleWrite( - const Block & header, - std::vector & scattered, - size_t bucket_idx, - UInt64 fine_grained_shuffle_stream_count, - size_t num_columns, - int16_t partition_id, - MPPDataPacketVersion version, - CompressionMethod compression_method); + explicit MPPTunnelSetBase(const String & req_id) + : log(Logger::get(req_id)) + {} + + void write(TrackedMppDataPacketPtr && data, size_t index); + void nonBlockingWrite(TrackedMppDataPacketPtr && data, size_t index); + + void write(tipb::SelectResponse & response, size_t index); + void nonBlockingWrite(tipb::SelectResponse & response, size_t index); + /// this is a execution summary writing. - /// for both broadcast writing and partition/fine grained shuffle writing, only - /// return meaningful execution summary for the first tunnel, + /// only return meaningful execution summary for the first tunnel, /// because in TiDB, it does not know enough information /// about the execution details for the mpp query, it just /// add up all the execution summaries for the same executor, @@ -84,7 +63,8 @@ class MPPTunnelSetBase : private boost::noncopyable const std::vector & getTunnels() const { return tunnels; } -private: + bool isReadyForWrite() const; + bool isLocal(size_t index) const; private: @@ -92,8 +72,6 @@ class MPPTunnelSetBase : private boost::noncopyable std::unordered_map receiver_task_id_to_index_map; const LoggerPtr log; - std::vector result_field_types; - int external_thread_cnt = 0; }; diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp new file mode 100644 index 00000000000..c09803fcbe9 --- /dev/null +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.cpp @@ -0,0 +1,244 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace +{ +void checkPacketSize(size_t size) +{ + static constexpr size_t max_packet_size = 1u << 31; + RUNTIME_CHECK(size < max_packet_size, fmt::format("Packet is too large to send, size : {}", size)); +} + +void updatePartitionWriterMetrics(size_t packet_bytes, bool is_local) +{ + // statistic + GET_METRIC(tiflash_exchange_data_bytes, type_hash_original).Increment(packet_bytes); + // compression method is always NONE + if (is_local) + GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_local).Increment(packet_bytes); + else + GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_remote).Increment(packet_bytes); +} + +void updatePartitionWriterMetrics(CompressionMethod method, size_t original_size, size_t sz, bool is_local) +{ + // statistic + GET_METRIC(tiflash_exchange_data_bytes, type_hash_original).Increment(original_size); + + switch (method) + { + case CompressionMethod::NONE: + { + if (is_local) + { + GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_local).Increment(sz); + } + else + { + GET_METRIC(tiflash_exchange_data_bytes, type_hash_none_compression_remote).Increment(sz); + } + break; + } + case CompressionMethod::LZ4: + { + GET_METRIC(tiflash_exchange_data_bytes, type_hash_lz4_compression).Increment(sz); + break; + } + case CompressionMethod::ZSTD: + { + GET_METRIC(tiflash_exchange_data_bytes, type_hash_zstd_compression).Increment(sz); + break; + } + default: + break; + } +} +} // namespace + +MPPTunnelSetWriterBase::MPPTunnelSetWriterBase( + const MPPTunnelSetPtr & mpp_tunnel_set_, + const std::vector & result_field_types_, + const String & req_id) + : mpp_tunnel_set(mpp_tunnel_set_) + , result_field_types(result_field_types_) + , log(Logger::get(req_id)) +{ + RUNTIME_CHECK(mpp_tunnel_set->getPartitionNum() > 0); +} + +void MPPTunnelSetWriterBase::write(tipb::SelectResponse & response) +{ + checkPacketSize(response.ByteSizeLong()); + // for root mpp task, only one tunnel will connect to tidb/tispark. + writeToTunnel(response, 0); +} + +void MPPTunnelSetWriterBase::broadcastOrPassThroughWrite(Blocks & blocks) +{ + auto && tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); + if (!tracked_packet) + return; + + auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); + checkPacketSize(packet_bytes); + // TODO avoid copy packet for broadcast. + for (size_t i = 1; i < getPartitionNum(); ++i) + writeToTunnel(tracked_packet->copy(), i); + writeToTunnel(std::move(tracked_packet), 0); + { + // statistic + size_t data_bytes = 0; + size_t local_data_bytes = 0; + { + auto tunnel_cnt = getPartitionNum(); + size_t local_tunnel_cnt = 0; + for (size_t i = 0; i < tunnel_cnt; ++i) + { + local_tunnel_cnt += mpp_tunnel_set->isLocal(i); + } + data_bytes = packet_bytes * tunnel_cnt; + local_data_bytes = packet_bytes * local_tunnel_cnt; + } + GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_original).Increment(data_bytes); + GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_none_compression_local).Increment(local_data_bytes); + GET_METRIC(tiflash_exchange_data_bytes, type_broadcast_passthrough_none_compression_remote).Increment(data_bytes - local_data_bytes); + } +} + +void MPPTunnelSetWriterBase::partitionWrite(Blocks & blocks, int16_t partition_id) +{ + auto && tracked_packet = MPPTunnelSetHelper::ToPacketV0(blocks, result_field_types); + if (!tracked_packet) + return; + auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); + checkPacketSize(packet_bytes); + writeToTunnel(std::move(tracked_packet), partition_id); + updatePartitionWriterMetrics(packet_bytes, mpp_tunnel_set->isLocal(partition_id)); +} + +void MPPTunnelSetWriterBase::partitionWrite( + const Block & header, + std::vector && part_columns, + int16_t partition_id, + MPPDataPacketVersion version, + CompressionMethod compression_method) +{ + assert(version > MPPDataPacketV0); + + bool is_local = mpp_tunnel_set->isLocal(partition_id); + compression_method = is_local ? CompressionMethod::NONE : compression_method; + + size_t original_size = 0; + auto tracked_packet = MPPTunnelSetHelper::ToPacket(header, std::move(part_columns), version, compression_method, original_size); + if (!tracked_packet) + return; + + auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); + checkPacketSize(packet_bytes); + writeToTunnel(std::move(tracked_packet), partition_id); + updatePartitionWriterMetrics(compression_method, original_size, packet_bytes, is_local); +} + +void MPPTunnelSetWriterBase::fineGrainedShuffleWrite( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t partition_id, + MPPDataPacketVersion version, + CompressionMethod compression_method) +{ + if (version == MPPDataPacketV0) + return fineGrainedShuffleWrite(header, scattered, bucket_idx, fine_grained_shuffle_stream_count, num_columns, partition_id); + + bool is_local = mpp_tunnel_set->isLocal(partition_id); + compression_method = is_local ? CompressionMethod::NONE : compression_method; + + size_t original_size = 0; + auto tracked_packet = MPPTunnelSetHelper::ToFineGrainedPacket( + header, + scattered, + bucket_idx, + fine_grained_shuffle_stream_count, + num_columns, + version, + compression_method, + original_size); + + if unlikely (tracked_packet->getPacket().chunks_size() <= 0) + return; + + auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); + checkPacketSize(packet_bytes); + writeToTunnel(std::move(tracked_packet), partition_id); + updatePartitionWriterMetrics(compression_method, original_size, packet_bytes, is_local); +} + +void MPPTunnelSetWriterBase::fineGrainedShuffleWrite( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t partition_id) +{ + auto tracked_packet = MPPTunnelSetHelper::ToFineGrainedPacketV0( + header, + scattered, + bucket_idx, + fine_grained_shuffle_stream_count, + num_columns, + result_field_types); + + if unlikely (tracked_packet->getPacket().chunks_size() <= 0) + return; + + auto packet_bytes = tracked_packet->getPacket().ByteSizeLong(); + checkPacketSize(packet_bytes); + writeToTunnel(std::move(tracked_packet), partition_id); + updatePartitionWriterMetrics(packet_bytes, mpp_tunnel_set->isLocal(partition_id)); +} + +void SyncMPPTunnelSetWriter::writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) +{ + mpp_tunnel_set->write(std::move(data), index); +} + +void SyncMPPTunnelSetWriter::writeToTunnel(tipb::SelectResponse & response, size_t index) +{ + mpp_tunnel_set->write(response, index); +} + +void AsyncMPPTunnelSetWriter::writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) +{ + mpp_tunnel_set->nonBlockingWrite(std::move(data), index); +} + +void AsyncMPPTunnelSetWriter::writeToTunnel(tipb::SelectResponse & response, size_t index) +{ + mpp_tunnel_set->nonBlockingWrite(response, index); +} +} // namespace DB diff --git a/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h new file mode 100644 index 00000000000..87bb4b63e6f --- /dev/null +++ b/dbms/src/Flash/Mpp/MPPTunnelSetWriter.h @@ -0,0 +1,111 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +namespace DB +{ +class MPPTunnelSetWriterBase : private boost::noncopyable +{ +public: + MPPTunnelSetWriterBase( + const MPPTunnelSetPtr & mpp_tunnel_set_, + const std::vector & result_field_types_, + const String & req_id); + + virtual ~MPPTunnelSetWriterBase() = default; + + // this is a root mpp writing. + void write(tipb::SelectResponse & response); + // this is a broadcast or pass through writing. + // data codec version V0 + void broadcastOrPassThroughWrite(Blocks & blocks); + // this is a partition writing. + // data codec version V0 + void partitionWrite(Blocks & blocks, int16_t partition_id); + // data codec version > V0 + void partitionWrite(const Block & header, std::vector && part_columns, int16_t partition_id, MPPDataPacketVersion version, CompressionMethod compression_method); + // this is a fine grained shuffle writing. + // data codec version V0 + void fineGrainedShuffleWrite( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t partition_id); + void fineGrainedShuffleWrite( + const Block & header, + std::vector & scattered, + size_t bucket_idx, + UInt64 fine_grained_shuffle_stream_count, + size_t num_columns, + int16_t partition_id, + MPPDataPacketVersion version, + CompressionMethod compression_method); + + uint16_t getPartitionNum() const { return mpp_tunnel_set->getPartitionNum(); } + + virtual bool isReadyForWrite() const = 0; + +protected: + virtual void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) = 0; + virtual void writeToTunnel(tipb::SelectResponse & response, size_t index) = 0; + +protected: + MPPTunnelSetPtr mpp_tunnel_set; + std::vector result_field_types; + const LoggerPtr log; +}; + +class SyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase +{ +public: + SyncMPPTunnelSetWriter( + const MPPTunnelSetPtr & mpp_tunnel_set_, + const std::vector & result_field_types_, + const String & req_id) + : MPPTunnelSetWriterBase(mpp_tunnel_set_, result_field_types_, req_id) + {} + + // For sync writer, `isReadyForWrite` will not be called, so an exception is thrown here. + bool isReadyForWrite() const override { throw Exception("Unsupport sync writer"); } + +protected: + void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; + void writeToTunnel(tipb::SelectResponse & response, size_t index) override; +}; +using SyncMPPTunnelSetWriterPtr = std::shared_ptr; + +class AsyncMPPTunnelSetWriter : public MPPTunnelSetWriterBase +{ +public: + AsyncMPPTunnelSetWriter( + const MPPTunnelSetPtr & mpp_tunnel_set_, + const std::vector & result_field_types_, + const String & req_id) + : MPPTunnelSetWriterBase(mpp_tunnel_set_, result_field_types_, req_id) + {} + + bool isReadyForWrite() const override { return mpp_tunnel_set->isReadyForWrite(); } + +protected: + void writeToTunnel(TrackedMppDataPacketPtr && data, size_t index) override; + void writeToTunnel(tipb::SelectResponse & response, size_t index) override; +}; +using AsyncMPPTunnelSetWriterPtr = std::shared_ptr; + +} // namespace DB diff --git a/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp b/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp index cd791cc94d5..849c9620f39 100644 --- a/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp +++ b/dbms/src/Flash/Mpp/ReceiverChannelWriter.cpp @@ -37,7 +37,12 @@ inline void injectFailPointReceiverPushFail(bool & push_succeed [[maybe_unused]] } } // namespace -bool ReceiverChannelWriter::writeFineGrain(size_t source_index, const TrackedMppDataPacketPtr & tracked_packet, const mpp::Error * error_ptr, const String * resp_ptr) +bool ReceiverChannelWriter::writeFineGrain( + WriteToChannelFunc write_func, + size_t source_index, + const TrackedMppDataPacketPtr & tracked_packet, + const mpp::Error * error_ptr, + const String * resp_ptr) { bool success = true; auto & packet = tracked_packet->packet; @@ -73,14 +78,14 @@ bool ReceiverChannelWriter::writeFineGrain(size_t source_index, const TrackedMpp if (resp_ptr == nullptr && error_ptr == nullptr && chunks[i].empty()) continue; - std::shared_ptr recv_msg = std::make_shared( + auto recv_msg = std::make_shared( source_index, req_info, tracked_packet, error_ptr, resp_ptr, std::move(chunks[i])); - success = (*msg_channels)[i]->push(std::move(recv_msg)) == MPMCQueueResult::OK; + success = (write_func(i, std::move(recv_msg)) == MPMCQueueResult::OK); injectFailPointReceiverPushFail(success, mode); @@ -90,7 +95,12 @@ bool ReceiverChannelWriter::writeFineGrain(size_t source_index, const TrackedMpp return success; } -bool ReceiverChannelWriter::writeNonFineGrain(size_t source_index, const TrackedMppDataPacketPtr & tracked_packet, const mpp::Error * error_ptr, const String * resp_ptr) +bool ReceiverChannelWriter::writeNonFineGrain( + WriteToChannelFunc write_func, + size_t source_index, + const TrackedMppDataPacketPtr & tracked_packet, + const mpp::Error * error_ptr, + const String * resp_ptr) { bool success = true; auto & packet = tracked_packet->packet; @@ -101,7 +111,7 @@ bool ReceiverChannelWriter::writeNonFineGrain(size_t source_index, const Tracked if (!(resp_ptr == nullptr && error_ptr == nullptr && chunks.empty())) { - std::shared_ptr recv_msg = std::make_shared( + auto recv_msg = std::make_shared( source_index, req_info, tracked_packet, @@ -109,9 +119,19 @@ bool ReceiverChannelWriter::writeNonFineGrain(size_t source_index, const Tracked resp_ptr, std::move(chunks)); - success = (*msg_channels)[0]->push(std::move(recv_msg)) == MPMCQueueResult::OK; + success = write_func(0, std::move(recv_msg)) == MPMCQueueResult::OK; injectFailPointReceiverPushFail(success, mode); } return success; } + +bool ReceiverChannelWriter::isReadyForWrite() const +{ + for (const auto & msg_channel : *msg_channels) + { + if (msg_channel->isFull()) + return false; + } + return true; +} } // namespace DB diff --git a/dbms/src/Flash/Mpp/ReceiverChannelWriter.h b/dbms/src/Flash/Mpp/ReceiverChannelWriter.h index c609fb75f7d..5d35b565b74 100644 --- a/dbms/src/Flash/Mpp/ReceiverChannelWriter.h +++ b/dbms/src/Flash/Mpp/ReceiverChannelWriter.h @@ -19,6 +19,8 @@ #include #include +#include + namespace DB { namespace FailPoints @@ -78,6 +80,7 @@ struct ReceivedMessage packet->switchMemTracker(current_memory_tracker); } }; +using ReceivedMessagePtr = std::shared_ptr; enum class ReceiverMode { @@ -99,25 +102,42 @@ class ReceiverChannelWriter , mode(mode_) {} - // "write" means writing the packet to the channel which is a MPMCQueue. + // "write" means writing the packet to the channel which is a ConcurrentIOQueue. + // + // If non_blocking: + // call ConcurrentIOQueue::nonBlockingPush + // If !non_blocking: + // call ConcurrentIOQueue::push // // If enable_fine_grained_shuffle: // Seperate chunks according to packet.stream_ids[i], then push to msg_channels[stream_id]. // If fine grained_shuffle is disabled: // Push all chunks to msg_channels[0]. + // // Return true if all push succeed, otherwise return false. // NOTE: shared_ptr will be hold by all ExchangeReceiverBlockInputStream to make chunk pointer valid. - template + template bool write(size_t source_index, const TrackedMppDataPacketPtr & tracked_packet) { - const mpp::Error * error_ptr = getErrorPtr(tracked_packet->packet); - const String * resp_ptr = getRespPtr(tracked_packet->packet); + const auto & packet = tracked_packet->packet; + const mpp::Error * error_ptr = packet.has_error() ? &packet.error() : nullptr; + const String * resp_ptr = packet.data().empty() ? nullptr : &packet.data(); + + WriteToChannelFunc write_func; + if constexpr (non_blocking) + write_func = [&](size_t i, ReceivedMessagePtr && recv_msg) { + return (*msg_channels)[i]->nonBlockingPush(std::move(recv_msg)); + }; + else + write_func = [&](size_t i, ReceivedMessagePtr && recv_msg) { + return (*msg_channels)[i]->push(std::move(recv_msg)); + }; bool success; if constexpr (enable_fine_grained_shuffle) - success = writeFineGrain(source_index, tracked_packet, error_ptr, resp_ptr); + success = writeFineGrain(write_func, source_index, tracked_packet, error_ptr, resp_ptr); else - success = writeNonFineGrain(source_index, tracked_packet, error_ptr, resp_ptr); + success = writeNonFineGrain(write_func, source_index, tracked_packet, error_ptr, resp_ptr); if (likely(success)) ExchangeReceiverMetric::addDataSizeMetric(*data_size_in_queue, tracked_packet->getPacket().ByteSizeLong()); @@ -125,23 +145,23 @@ class ReceiverChannelWriter return success; } -private: - static const mpp::Error * getErrorPtr(const mpp::MPPDataPacket & packet) - { - if (unlikely(packet.has_error())) - return &packet.error(); - return nullptr; - } + bool isReadyForWrite() const; - static const String * getRespPtr(const mpp::MPPDataPacket & packet) - { - if (unlikely(!packet.data().empty())) - return &packet.data(); - return nullptr; - } - - bool writeFineGrain(size_t source_index, const TrackedMppDataPacketPtr & tracked_packet, const mpp::Error * error_ptr, const String * resp_ptr); - bool writeNonFineGrain(size_t source_index, const TrackedMppDataPacketPtr & tracked_packet, const mpp::Error * error_ptr, const String * resp_ptr); +private: + using WriteToChannelFunc = std::function; + + bool writeFineGrain( + WriteToChannelFunc write_func, + size_t source_index, + const TrackedMppDataPacketPtr & tracked_packet, + const mpp::Error * error_ptr, + const String * resp_ptr); + bool writeNonFineGrain( + WriteToChannelFunc write_func, + size_t source_index, + const TrackedMppDataPacketPtr & tracked_packet, + const mpp::Error * error_ptr, + const String * resp_ptr); std::atomic * data_size_in_queue; std::vector * msg_channels; diff --git a/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp b/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp new file mode 100644 index 00000000000..0d43b6fd162 --- /dev/null +++ b/dbms/src/Flash/Mpp/newMPPExchangeWriter.cpp @@ -0,0 +1,158 @@ +// Copyright 2022 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ +namespace +{ +template +std::unique_ptr buildMPPExchangeWriter( + const ExchangeWriterPtr & writer, + const std::vector & partition_col_ids, + const TiDB::TiDBCollators & partition_col_collators, + const tipb::ExchangeType & exchange_type, + Int64 records_per_chunk, + Int64 batch_send_min_limit, + DAGContext & dag_context, + bool enable_fine_grained_shuffle, + UInt64 fine_grained_shuffle_stream_count, + UInt64 fine_grained_shuffle_batch_size, + tipb::CompressionMode compression_mode, + Int64 batch_send_min_limit_compression) +{ + if (dag_context.isRootMPPTask()) + { + // No need to use use data compression + RUNTIME_CHECK(compression_mode == tipb::CompressionMode::NONE); + + RUNTIME_CHECK(!enable_fine_grained_shuffle); + RUNTIME_CHECK(exchange_type == tipb::ExchangeType::PassThrough); + return std::make_unique>( + writer, + records_per_chunk, + batch_send_min_limit, + dag_context); + } + else + { + if (exchange_type == tipb::ExchangeType::Hash) + { + auto mpp_version = dag_context.getMPPTaskMeta().mpp_version(); + auto data_codec_version = mpp_version == MppVersionV0 + ? MPPDataPacketV0 + : MPPDataPacketV1; + + if (enable_fine_grained_shuffle) + { + return std::make_unique>( + writer, + partition_col_ids, + partition_col_collators, + dag_context, + fine_grained_shuffle_stream_count, + fine_grained_shuffle_batch_size, + data_codec_version, + compression_mode); + } + else + { + auto chosen_batch_send_min_limit = mpp_version == MppVersionV0 + ? batch_send_min_limit + : batch_send_min_limit_compression; + + return std::make_unique>( + writer, + partition_col_ids, + partition_col_collators, + chosen_batch_send_min_limit, + dag_context, + data_codec_version, + compression_mode); + } + } + else + { + // TODO: support data compression if necessary + RUNTIME_CHECK(compression_mode == tipb::CompressionMode::NONE); + + RUNTIME_CHECK(!enable_fine_grained_shuffle); + return std::make_unique>( + writer, + batch_send_min_limit, + dag_context); + } + } +} +} // namespace + +std::unique_ptr newMPPExchangeWriter( + const std::vector & partition_col_ids, + const TiDB::TiDBCollators & partition_col_collators, + const tipb::ExchangeType & exchange_type, + Int64 records_per_chunk, + Int64 batch_send_min_limit, + DAGContext & dag_context, + bool enable_fine_grained_shuffle, + UInt64 fine_grained_shuffle_stream_count, + UInt64 fine_grained_shuffle_batch_size, + tipb::CompressionMode compression_mode, + Int64 batch_send_min_limit_compression, + const String & req_id, + bool is_async) +{ + RUNTIME_CHECK_MSG(dag_context.isMPPTask() && dag_context.tunnel_set != nullptr, "exchange writer only run in MPP"); + if (is_async) + { + auto writer = std::make_shared(dag_context.tunnel_set, dag_context.result_field_types, req_id); + return buildMPPExchangeWriter( + writer, + partition_col_ids, + partition_col_collators, + exchange_type, + records_per_chunk, + batch_send_min_limit, + dag_context, + enable_fine_grained_shuffle, + fine_grained_shuffle_stream_count, + fine_grained_shuffle_batch_size, + compression_mode, + batch_send_min_limit_compression); + } + else + { + auto writer = std::make_shared(dag_context.tunnel_set, dag_context.result_field_types, req_id); + return buildMPPExchangeWriter( + writer, + partition_col_ids, + partition_col_collators, + exchange_type, + records_per_chunk, + batch_send_min_limit, + dag_context, + enable_fine_grained_shuffle, + fine_grained_shuffle_stream_count, + fine_grained_shuffle_batch_size, + compression_mode, + batch_send_min_limit_compression); + } +} +} // namespace DB diff --git a/dbms/src/Flash/Mpp/newMPPExchangeWriter.h b/dbms/src/Flash/Mpp/newMPPExchangeWriter.h index 9cb23d11734..e67eebb0c59 100644 --- a/dbms/src/Flash/Mpp/newMPPExchangeWriter.h +++ b/dbms/src/Flash/Mpp/newMPPExchangeWriter.h @@ -14,17 +14,12 @@ #pragma once -#include -#include -#include -#include -#include +#include +#include namespace DB { -template std::unique_ptr newMPPExchangeWriter( - const ExchangeWriterPtr & writer, const std::vector & partition_col_ids, const TiDB::TiDBCollators & partition_col_collators, const tipb::ExchangeType & exchange_type, @@ -35,71 +30,8 @@ std::unique_ptr newMPPExchangeWriter( UInt64 fine_grained_shuffle_stream_count, UInt64 fine_grained_shuffle_batch_size, tipb::CompressionMode compression_mode, - Int64 batch_send_min_limit_compression) -{ - RUNTIME_CHECK(dag_context.isMPPTask()); - if (dag_context.isRootMPPTask()) - { - // No need to use use data compression - RUNTIME_CHECK(compression_mode == tipb::CompressionMode::NONE); - - RUNTIME_CHECK(!enable_fine_grained_shuffle); - RUNTIME_CHECK(exchange_type == tipb::ExchangeType::PassThrough); - return std::make_unique>( - writer, - records_per_chunk, - batch_send_min_limit, - dag_context); - } - else - { - if (exchange_type == tipb::ExchangeType::Hash) - { - auto mpp_version = dag_context.getMPPTaskMeta().mpp_version(); - auto data_codec_version = mpp_version == MppVersionV0 - ? MPPDataPacketV0 - : MPPDataPacketV1; - - if (enable_fine_grained_shuffle) - { - return std::make_unique>( - writer, - partition_col_ids, - partition_col_collators, - dag_context, - fine_grained_shuffle_stream_count, - fine_grained_shuffle_batch_size, - data_codec_version, - compression_mode); - } - else - { - auto chosen_batch_send_min_limit = mpp_version == MppVersionV0 - ? batch_send_min_limit - : batch_send_min_limit_compression; - - return std::make_unique>( - writer, - partition_col_ids, - partition_col_collators, - chosen_batch_send_min_limit, - dag_context, - data_codec_version, - compression_mode); - } - } - else - { - // TODO: support data compression if necessary - RUNTIME_CHECK(compression_mode == tipb::CompressionMode::NONE); - - RUNTIME_CHECK(!enable_fine_grained_shuffle); - return std::make_unique>( - writer, - batch_send_min_limit, - dag_context); - } - } -} + Int64 batch_send_min_limit_compression, + const String & req_id, + bool is_async = false); } // namespace DB diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp index 0ce4fda70dd..14757049eaa 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpp_exchange_writer.cpp @@ -193,12 +193,6 @@ struct MockExchangeWriter } static void write(tipb::SelectResponse &) { FAIL() << "cannot reach here, only consider CH Block format"; } - void sendExecutionSummary(const tipb::SelectResponse & response) - { - auto tracked_packet = std::make_shared(MPPDataPacketV0); - tracked_packet->serializeByResponse(response); - checker(tracked_packet, 0); - } uint16_t getPartitionNum() const { return part_num; } bool isLocal(size_t index) const { @@ -206,6 +200,7 @@ struct MockExchangeWriter // make only part 0 use local tunnel return index == 0; } + bool isReadyForWrite() const { throw Exception("Unsupport async write"); } private: MockExchangeWriterChecker checker; diff --git a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp index a2395c2b217..06495526f34 100644 --- a/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp +++ b/dbms/src/Flash/Mpp/tests/gtest_mpptunnel.cpp @@ -780,5 +780,73 @@ TEST_F(TestMPPTunnel, LocalWriteAfterFinished) if (tunnel != nullptr) tunnel->waitForFinish(); } + +TEST_F(TestMPPTunnel, SyncTunnelNonBlockingWrite) +{ + auto writer_ptr = std::make_unique(); + auto mpp_tunnel_ptr = constructRemoteSyncTunnel(); + mpp_tunnel_ptr->connectSync(writer_ptr.get()); + GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); + + ASSERT_TRUE(mpp_tunnel_ptr->isReadyForWrite()); + mpp_tunnel_ptr->nonBlockingWrite(newDataPacket("First")); + mpp_tunnel_ptr->writeDone(); + GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); + + GTEST_ASSERT_EQ(writer_ptr->write_packet_vec.size(), 1); + GTEST_ASSERT_EQ(writer_ptr->write_packet_vec.back(), "First"); +} + +TEST_F(TestMPPTunnel, AsyncTunnelNonBlockingWrite) +{ + auto mpp_tunnel_ptr = constructRemoteAsyncTunnel(); + std::unique_ptr call_data = std::make_unique(); + mpp_tunnel_ptr->connectAsync(call_data.get()); + GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); + std::thread t(&MockAsyncCallData::run, call_data.get()); + + ASSERT_TRUE(mpp_tunnel_ptr->isReadyForWrite()); + mpp_tunnel_ptr->nonBlockingWrite(newDataPacket("First")); + mpp_tunnel_ptr->writeDone(); + GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); + t.join(); + + GTEST_ASSERT_EQ(call_data->write_packet_vec.size(), 1); + GTEST_ASSERT_EQ(call_data->write_packet_vec.back(), "First"); +} + +TEST_F(TestMPPTunnel, LocalTunnelNonBlockingWrite) +{ + auto [receiver, tunnels] = prepareLocal(1); + const auto & mpp_tunnel_ptr = tunnels.back(); + GTEST_ASSERT_EQ(getTunnelConnectedFlag(mpp_tunnel_ptr), true); + std::thread t(&MockExchangeReceiver::receiveAll, receiver.get()); + + ASSERT_TRUE(mpp_tunnel_ptr->isReadyForWrite()); + mpp_tunnel_ptr->nonBlockingWrite(newDataPacket("First")); + mpp_tunnel_ptr->writeDone(); + GTEST_ASSERT_EQ(getTunnelFinishedFlag(mpp_tunnel_ptr), true); + t.join(); + + GTEST_ASSERT_EQ(receiver->getReceivedMsgs().size(), 1); + GTEST_ASSERT_EQ(receiver->getReceivedMsgs().back()->packet->getPacket().data(), "First"); +} + +TEST_F(TestMPPTunnel, isReadyForWriteTimeout) +try +{ + timeout = std::chrono::seconds(1); + auto mpp_tunnel_ptr = constructRemoteSyncTunnel(); + Stopwatch stop_watch{CLOCK_MONOTONIC_COARSE}; + while (stop_watch.elapsedSeconds() < 3 * timeout.count()) + { + ASSERT_FALSE(mpp_tunnel_ptr->isReadyForWrite()); + } + GTEST_FAIL(); +} +catch (Exception & e) +{ + GTEST_ASSERT_EQ(e.message(), "0000_0001 is timeout"); +} } // namespace tests } // namespace DB diff --git a/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp b/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp index b694b49da98..22e408eea74 100644 --- a/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp +++ b/dbms/src/Flash/Pipeline/Exec/PipelineExec.cpp @@ -17,12 +17,28 @@ namespace DB { +void PipelineExec::executePrefix() +{ + sink_op->operatePrefix(); + for (auto it = transform_ops.rbegin(); it != transform_ops.rend(); ++it) + (*it)->operatePrefix(); + source_op->operatePrefix(); +} + +void PipelineExec::executeSuffix() +{ + sink_op->operateSuffix(); + for (auto it = transform_ops.rbegin(); it != transform_ops.rend(); ++it) + (*it)->operateSuffix(); + source_op->operateSuffix(); +} + OperatorStatus PipelineExec::execute() { auto op_status = executeImpl(); #ifndef NDEBUG // `NEED_INPUT` means that pipeline_exec need data to do the calculations and expect the next call to `execute`. - assertOperatorStatus(op_status, {OperatorStatus::NEED_INPUT}); + assertOperatorStatus(op_status, {OperatorStatus::FINISHED, OperatorStatus::NEED_INPUT}); #endif return op_status; } diff --git a/dbms/src/Flash/Pipeline/Exec/PipelineExec.h b/dbms/src/Flash/Pipeline/Exec/PipelineExec.h index eef9a75c557..0230cfe10da 100644 --- a/dbms/src/Flash/Pipeline/Exec/PipelineExec.h +++ b/dbms/src/Flash/Pipeline/Exec/PipelineExec.h @@ -35,6 +35,9 @@ class PipelineExec , sink_op(std::move(sink_op_)) {} + void executePrefix(); + void executeSuffix(); + OperatorStatus execute(); OperatorStatus await(); diff --git a/dbms/src/Flash/Pipeline/Pipeline.cpp b/dbms/src/Flash/Pipeline/Pipeline.cpp index 0421c1cd4c0..2d8da185230 100644 --- a/dbms/src/Flash/Pipeline/Pipeline.cpp +++ b/dbms/src/Flash/Pipeline/Pipeline.cpp @@ -140,7 +140,6 @@ bool Pipeline::isSupported(const tipb::DAGRequest & dag_request) case tipb::ExecType::TypeSelection: case tipb::ExecType::TypeLimit: case tipb::ExecType::TypeTopN: - // Only support mock table_scan/exchange_sender/exchange_receiver in test mode now. case tipb::ExecType::TypeTableScan: case tipb::ExecType::TypeExchangeSender: case tipb::ExecType::TypeExchangeReceiver: diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp index f6e87a63a39..142750746e6 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.cpp @@ -34,6 +34,21 @@ EventTask::~EventTask() event.reset(); } +void EventTask::finalize() +{ + try + { + bool tmp = false; + if (finalized.compare_exchange_strong(tmp, true)) + finalizeImpl(); + } + catch (...) + { + // ignore exception from finalizeImpl. + // TODO add log here. + } +} + ExecTaskStatus EventTask::executeImpl() { return doTaskAction([&] { return doExecuteImpl(); }); diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h index 3d0d2a846f4..6b201df94b3 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/EventTask.h @@ -40,7 +40,8 @@ class EventTask : public Task virtual ExecTaskStatus doAwaitImpl() { return ExecTaskStatus::RUNNING; }; // Used to release held resources, just like `Event::finishImpl`. - virtual void finalize(){}; + void finalize(); + virtual void finalizeImpl(){}; private: template @@ -74,6 +75,7 @@ class EventTask : public Task private: PipelineExecutorStatus & exec_status; EventPtr event; + std::atomic_bool finalized{false}; }; } // namespace DB diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp b/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp index b9c1b4dfdaa..12a3ee4e591 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.cpp @@ -28,10 +28,13 @@ PipelineTask::PipelineTask( , pipeline_exec(std::move(pipeline_exec_)) { assert(pipeline_exec); + pipeline_exec->executePrefix(); } -void PipelineTask::finalize() +void PipelineTask::finalizeImpl() { + assert(pipeline_exec); + pipeline_exec->executeSuffix(); pipeline_exec.reset(); } diff --git a/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.h b/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.h index ec12fe48f63..2758723e087 100644 --- a/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.h +++ b/dbms/src/Flash/Pipeline/Schedule/Tasks/PipelineTask.h @@ -33,7 +33,7 @@ class PipelineTask : public EventTask ExecTaskStatus doAwaitImpl() override; - void finalize() override; + void finalizeImpl() override; private: PipelineExecPtr pipeline_exec; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp index 1dd99163426..7afb65a71f3 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.cpp @@ -17,10 +17,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include @@ -77,16 +79,35 @@ void PhysicalExchangeReceiver::buildBlockInputStreamImpl(DAGPipeline & pipeline, for (size_t i = 0; i < stream_count; ++i) { - BlockInputStreamPtr stream = std::make_shared(mpp_exchange_receiver, - log->identifier(), - execId(), - /*stream_id=*/enable_fine_grained_shuffle ? i : 0); + BlockInputStreamPtr stream = std::make_shared( + mpp_exchange_receiver, + log->identifier(), + execId(), + /*stream_id=*/enable_fine_grained_shuffle ? i : 0); exchange_receiver_io_input_streams.push_back(stream); stream->setExtraInfo(extra_info); pipeline.streams.push_back(stream); } } +void PhysicalExchangeReceiver::buildPipelineExec(PipelineExecGroupBuilder & group_builder, Context & /*context*/, size_t concurrency) +{ + // TODO support fine grained shuffle. + const bool enable_fine_grained_shuffle = enableFineGrainedShuffle(mpp_exchange_receiver->getFineGrainedShuffleStreamCount()); + RUNTIME_CHECK(!enable_fine_grained_shuffle); + + // TODO choose a more reasonable concurrency. + group_builder.init(concurrency); + group_builder.transform([&](auto & builder) { + builder.setSourceOp(std::make_unique( + group_builder.exec_status, + mpp_exchange_receiver, + /*stream_id=*/0, + log->identifier(), + execId())); + }); +} + void PhysicalExchangeReceiver::finalize(const Names & parent_require) { FinalizeHelper::checkSchemaContainsParentRequire(schema, parent_require); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h index 52fe31ab61f..f73084d4244 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeReceiver.h @@ -45,6 +45,8 @@ class PhysicalExchangeReceiver : public PhysicalLeaf return mpp_exchange_receiver->getSourceNum(); } + void buildPipelineExec(PipelineExecGroupBuilder & group_builder, Context & /*context*/, size_t /*concurrency*/) override; + private: void buildBlockInputStreamImpl(DAGPipeline & pipeline, Context & context, size_t max_streams) override; diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp index a887d777188..8c06dd5c1ac 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.cpp @@ -19,9 +19,10 @@ #include #include #include +#include #include #include - +#include namespace DB { @@ -59,8 +60,6 @@ void PhysicalExchangeSender::buildBlockInputStreamImpl(DAGPipeline & pipeline, C auto & dag_context = *context.getDAGContext(); restoreConcurrency(pipeline, dag_context.final_concurrency, log); - RUNTIME_ASSERT(dag_context.isMPPTask() && dag_context.tunnel_set != nullptr, log, "exchange_sender only run in MPP"); - String extra_info; if (fine_grained_shuffle.enable()) { @@ -71,7 +70,6 @@ void PhysicalExchangeSender::buildBlockInputStreamImpl(DAGPipeline & pipeline, C pipeline.transform([&](auto & stream) { // construct writer std::unique_ptr response_writer = newMPPExchangeWriter( - dag_context.tunnel_set, partition_col_ids, partition_col_collators, exchange_type, @@ -82,12 +80,38 @@ void PhysicalExchangeSender::buildBlockInputStreamImpl(DAGPipeline & pipeline, C fine_grained_shuffle.stream_count, fine_grained_shuffle.batch_size, compression_mode, - context.getSettingsRef().batch_send_min_limit_compression); + context.getSettingsRef().batch_send_min_limit_compression, + log->identifier()); stream = std::make_shared(stream, std::move(response_writer), log->identifier()); stream->setExtraInfo(extra_info); }); } +void PhysicalExchangeSender::buildPipelineExec(PipelineExecGroupBuilder & group_builder, Context & context, size_t /*concurrency*/) +{ + // TODO support fine grained shuffle + RUNTIME_CHECK(!fine_grained_shuffle.enable()); + + group_builder.transform([&](auto & builder) { + // construct writer + std::unique_ptr response_writer = newMPPExchangeWriter( + partition_col_ids, + partition_col_collators, + exchange_type, + context.getSettingsRef().dag_records_per_chunk, + context.getSettingsRef().batch_send_min_limit, + *context.getDAGContext(), + fine_grained_shuffle.enable(), + fine_grained_shuffle.stream_count, + fine_grained_shuffle.batch_size, + compression_mode, + context.getSettingsRef().batch_send_min_limit_compression, + log->identifier(), + /*is_async=*/true); + builder.setSinkOp(std::make_unique(group_builder.exec_status, std::move(response_writer), log->identifier())); + }); +} + void PhysicalExchangeSender::finalize(const Names & parent_require) { child->finalize(parent_require); diff --git a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h index 88fc4b7773b..121238cc6eb 100644 --- a/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h +++ b/dbms/src/Flash/Planner/Plans/PhysicalExchangeSender.h @@ -53,6 +53,8 @@ class PhysicalExchangeSender : public PhysicalUnary const Block & getSampleBlock() const override; + void buildPipelineExec(PipelineExecGroupBuilder & group_builder, Context & context, size_t /*concurrency*/) override; + private: void buildBlockInputStreamImpl(DAGPipeline & pipeline, Context & context, size_t max_streams) override; diff --git a/dbms/src/Flash/executeQuery.cpp b/dbms/src/Flash/executeQuery.cpp index 6bd82a45ab1..102d4bd3a7a 100644 --- a/dbms/src/Flash/executeQuery.cpp +++ b/dbms/src/Flash/executeQuery.cpp @@ -153,8 +153,8 @@ QueryExecutorPtr executeAsBlockIO(Context & context, bool internal) QueryExecutorPtr queryExecute(Context & context, bool internal) { - // now only support pipeline model in executor/interpreter test. - if ((context.isExecutorTest() || context.isInterpreterTest()) + // now only support pipeline model in test mode. + if (context.isTest() && context.getSettingsRef().enable_planner && context.getSettingsRef().enable_pipeline) { diff --git a/dbms/src/Flash/tests/gtest_compute_server.cpp b/dbms/src/Flash/tests/gtest_compute_server.cpp index 8a81a5650dd..bf6d5b15806 100644 --- a/dbms/src/Flash/tests/gtest_compute_server.cpp +++ b/dbms/src/Flash/tests/gtest_compute_server.cpp @@ -102,9 +102,118 @@ class ComputeServerRunner : public DB::tests::MPPTaskTestUtils }; +#define WRAP_FOR_SERVER_TEST_BEGIN \ + std::vector pipeline_bools{false, true}; \ + for (auto enable_pipeline : pipeline_bools) \ + { \ + enablePipeline(enable_pipeline); + +#define WRAP_FOR_SERVER_TEST_END \ + } + +TEST_F(ComputeServerRunner, simpleExchange) +try +{ + std::vector::FieldType>> s1_col(10000); + for (size_t i = 0; i < s1_col.size(); ++i) + s1_col[i] = i; + auto expected_cols = {toNullableVec("s1", s1_col)}; + context.addMockTable( + {"test_db", "big_table"}, + {{"s1", TiDB::TP::TypeLong}}, + expected_cols); + + context.context.setSetting("max_block_size", Field(static_cast(100))); + + WRAP_FOR_SERVER_TEST_BEGIN + // For PassThrough and Broadcast, use only one server for testing, as multiple servers will double the result size. + { + startServers(1); + { + std::vector expected_strings = { + R"( +exchange_sender_2 | type:PassThrough, {<0, Long>} + project_1 | {<0, Long>} + table_scan_0 | {<0, Long>})"}; + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT( + context + .scan("test_db", "big_table") + .project({"s1"}), + expected_strings, + expected_cols); + } + { + std::vector expected_strings = { + R"( +exchange_sender_1 | type:PassThrough, {<0, Long>} + table_scan_0 | {<0, Long>})", + R"( +exchange_sender_4 | type:PassThrough, {<0, Long>} + project_3 | {<0, Long>} + exchange_receiver_2 | type:PassThrough, {<0, Long>})"}; + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT( + context + .scan("test_db", "big_table") + .exchangeSender(tipb::ExchangeType::PassThrough) + .exchangeReceiver("recv", {{"s1", TiDB::TP::TypeLong}}) + .project({"s1"}), + expected_strings, + expected_cols); + } + { + std::vector expected_strings = { + R"( +exchange_sender_1 | type:Broadcast, {<0, Long>} + table_scan_0 | {<0, Long>})", + R"( +exchange_sender_4 | type:PassThrough, {<0, Long>} + project_3 | {<0, Long>} + exchange_receiver_2 | type:Broadcast, {<0, Long>})"}; + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT( + context + .scan("test_db", "big_table") + .exchangeSender(tipb::ExchangeType::Broadcast) + .exchangeReceiver("recv", {{"s1", TiDB::TP::TypeLong}}) + .project({"s1"}), + expected_strings, + expected_cols); + } + } + // For Hash, multiple servers will not double the result. + { + startServers(2); + std::vector expected_strings = { + R"( +exchange_sender_1 | type:Hash, {<0, Long>} + table_scan_0 | {<0, Long>})", + R"( +exchange_sender_1 | type:Hash, {<0, Long>} + table_scan_0 | {<0, Long>})", + R"( +exchange_sender_4 | type:PassThrough, {<0, Long>} + project_3 | {<0, Long>} + exchange_receiver_2 | type:Hash, {<0, Long>})", + R"( +exchange_sender_4 | type:PassThrough, {<0, Long>} + project_3 | {<0, Long>} + exchange_receiver_2 | type:Hash, {<0, Long>})"}; + ASSERT_MPPTASK_EQUAL_PLAN_AND_RESULT( + context + .scan("test_db", "big_table") + .exchangeSender(tipb::ExchangeType::Hash) + .exchangeReceiver("recv", {{"s1", TiDB::TP::TypeLong}}) + .project({"s1"}), + expected_strings, + expected_cols); + } + WRAP_FOR_SERVER_TEST_END +} +CATCH + TEST_F(ComputeServerRunner, runAggTasks) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(4); { std::vector expected_strings = { @@ -179,12 +288,14 @@ exchange_sender_3 | type:PassThrough, {<0, Long>} ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); } } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, runJoinTasks) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(3); { auto expected_cols = { @@ -223,7 +334,7 @@ try .scan("test_db", "l_table") .join(context.scan("test_db", "r_table"), tipb::JoinType::TypeLeftOuterJoin, {col("join_c")}), expected_strings, - expect_cols); + expected_cols); } { @@ -249,12 +360,14 @@ try ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); } } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, runJoinThenAggTasks) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(3); { std::vector expected_strings = { @@ -309,7 +422,7 @@ try .aggregation({Max(col("l_table.s"))}, {col("l_table.s")}) .project({col("max(l_table.s)"), col("l_table.s")}), expected_strings, - expect_cols); + expected_cols); } { @@ -343,12 +456,14 @@ try ASSERT_DAGREQUEST_EQAUL(expected_strings[i], tasks[i].dag_request); } } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, aggWithColumnPrune) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(3); context.addMockTable( @@ -390,12 +505,14 @@ try ASSERT_COLUMNS_EQ_UR(expected_cols, buildAndExecuteMPPTasks(request)); } } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, cancelAggTasks) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(4); { auto [query_id, res] = prepareMPPStreams(context @@ -406,12 +523,14 @@ try MockComputeServerManager::instance().cancelQuery(query_id); EXPECT_TRUE(assertQueryCancelled(query_id)); } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, cancelJoinTasks) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(4); { auto [query_id, res] = prepareMPPStreams(context @@ -421,12 +540,14 @@ try MockComputeServerManager::instance().cancelQuery(query_id); EXPECT_TRUE(assertQueryCancelled(query_id)); } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, cancelJoinThenAggTasks) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(4); { auto [query_id, _] = prepareMPPStreams(context @@ -438,12 +559,14 @@ try MockComputeServerManager::instance().cancelQuery(query_id); EXPECT_TRUE(assertQueryCancelled(query_id)); } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, multipleQuery) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(4); { auto [query_id1, res1] = prepareMPPStreams(context @@ -481,6 +604,7 @@ try EXPECT_TRUE(assertQueryCancelled(query_id)); } } + WRAP_FOR_SERVER_TEST_END } CATCH @@ -488,6 +612,7 @@ TEST_F(ComputeServerRunner, runCoprocessor) try { // In coprocessor test, we only need to start 1 server. + WRAP_FOR_SERVER_TEST_BEGIN startServers(1); { auto request = context @@ -499,12 +624,14 @@ try toNullableVec({{"apple", {}, "banana"}})}; ASSERT_COLUMNS_EQ_UR(expected_cols, executeCoprocessorTask(request)); } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, runFineGrainedShuffleJoinTest) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(3); constexpr size_t join_type_num = 7; constexpr tipb::JoinType join_types[join_type_num] = { @@ -522,7 +649,6 @@ try for (auto join_type : join_types) { - std::cout << "JoinType: " << static_cast(join_type) << std::endl; auto properties = DB::tests::getDAGPropertiesForTest(serverNum()); auto request = context .scan("test_db", "l_table_2") @@ -538,12 +664,14 @@ try const auto actual_cols = executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); ASSERT_COLUMNS_EQ_UR(expected_cols, actual_cols); } + WRAP_FOR_SERVER_TEST_END } CATCH TEST_F(ComputeServerRunner, runFineGrainedShuffleAggTest) try { + WRAP_FOR_SERVER_TEST_BEGIN startServers(3); // fine-grained shuffle is enabled. constexpr uint64_t enable = 8; @@ -562,8 +690,12 @@ try const auto actual_cols = executeMPPTasks(tasks, properties, MockComputeServerManager::instance().getServerConfigMap()); ASSERT_COLUMNS_EQ_UR(expected_cols, actual_cols); } + WRAP_FOR_SERVER_TEST_END } CATCH +#undef WRAP_FOR_SERVER_TEST_BEGIN +#undef WRAP_FOR_SERVER_TEST_END + } // namespace tests } // namespace DB diff --git a/dbms/src/Operators/ExchangeReceiverSourceOp.cpp b/dbms/src/Operators/ExchangeReceiverSourceOp.cpp new file mode 100644 index 00000000000..d07baffcaac --- /dev/null +++ b/dbms/src/Operators/ExchangeReceiverSourceOp.cpp @@ -0,0 +1,108 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +namespace DB +{ +void ExchangeReceiverSourceOp::operateSuffix() +{ + LOG_INFO(log, "finish read {} rows from exchange", total_rows); +} + +Block ExchangeReceiverSourceOp::popFromBlockQueue() +{ + assert(!block_queue.empty()); + Block block = std::move(block_queue.front()); + block_queue.pop(); + return block; +} + +OperatorStatus ExchangeReceiverSourceOp::readImpl(Block & block) +{ + if (!block_queue.empty()) + { + block = popFromBlockQueue(); + return OperatorStatus::HAS_OUTPUT; + } + + while (true) + { + assert(block_queue.empty()); + auto await_status = awaitImpl(); + if (await_status == OperatorStatus::HAS_OUTPUT) + { + assert(recv_res); + assert(recv_res->recv_status != ReceiveStatus::empty); + auto result = exchange_receiver->toExchangeReceiveResult( + *recv_res, + block_queue, + header, + decoder_ptr); + recv_res.reset(); + + if (result.meet_error) + { + LOG_WARNING(log, "exchange receiver meets error: {}", result.error_msg); + throw Exception(result.error_msg); + } + if (result.resp != nullptr && result.resp->has_error()) + { + LOG_WARNING(log, "exchange receiver meets error: {}", result.resp->error().DebugString()); + throw Exception(result.resp->error().DebugString()); + } + if (result.eof) + { + LOG_DEBUG(log, "exchange receiver meets eof"); + return OperatorStatus::HAS_OUTPUT; + } + + const auto & decode_detail = result.decode_detail; + total_rows += decode_detail.rows; + LOG_TRACE( + log, + "recv {} rows from exchange receiver for {}, total recv row num: {}", + decode_detail.rows, + result.req_info, + total_rows); + + if (decode_detail.rows <= 0) + continue; + + block = popFromBlockQueue(); + return OperatorStatus::HAS_OUTPUT; + } + assert(!recv_res); + return await_status; + } +} + +OperatorStatus ExchangeReceiverSourceOp::awaitImpl() +{ + if (!block_queue.empty() || recv_res) + return OperatorStatus::HAS_OUTPUT; + recv_res.emplace(exchange_receiver->nonBlockingReceive(stream_id)); + switch (recv_res->recv_status) + { + case ReceiveStatus::ok: + assert(recv_res->recv_msg); + return OperatorStatus::HAS_OUTPUT; + case ReceiveStatus::empty: + recv_res.reset(); + return OperatorStatus::WAITING; + case ReceiveStatus::eof: + return OperatorStatus::HAS_OUTPUT; + } +} +} // namespace DB diff --git a/dbms/src/Operators/ExchangeReceiverSourceOp.h b/dbms/src/Operators/ExchangeReceiverSourceOp.h new file mode 100644 index 00000000000..b6bfb4ad39f --- /dev/null +++ b/dbms/src/Operators/ExchangeReceiverSourceOp.h @@ -0,0 +1,69 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +namespace DB +{ +class ExchangeReceiverSourceOp : public SourceOp +{ +public: + ExchangeReceiverSourceOp( + PipelineExecutorStatus & exec_status_, + const std::shared_ptr & exchange_receiver_, + size_t stream_id_, + const String & req_id, + const String & executor_id) + : SourceOp(exec_status_) + , exchange_receiver(exchange_receiver_) + , stream_id(stream_id_) + , log(Logger::get(req_id, executor_id)) + { + setHeader(Block(getColumnWithTypeAndName(toNamesAndTypes(exchange_receiver->getOutputSchema())))); + decoder_ptr = std::make_unique(getHeader(), 8192); + } + + String getName() const override + { + return "ExchangeReceiverSourceOp"; + } + +protected: + OperatorStatus readImpl(Block & block) override; + + OperatorStatus awaitImpl() override; + + void operateSuffix() override; + +private: + Block popFromBlockQueue(); + +private: + // TODO support ConnectionProfileInfo. + // TODO support RemoteExecutionSummary. + std::shared_ptr exchange_receiver; + std::unique_ptr decoder_ptr; + uint64_t total_rows{}; + std::queue block_queue; + std::optional recv_res; + + size_t stream_id; + const LoggerPtr log; +}; +} // namespace DB diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.cpp b/dbms/src/Operators/ExchangeSenderSinkOp.cpp new file mode 100644 index 00000000000..83c90c2d32a --- /dev/null +++ b/dbms/src/Operators/ExchangeSenderSinkOp.cpp @@ -0,0 +1,73 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +namespace DB +{ +namespace FailPoints +{ +extern const char hang_in_execution[]; +extern const char exception_during_mpp_non_root_task_run[]; +extern const char exception_during_mpp_root_task_run[]; +} // namespace FailPoints + +void ExchangeSenderSinkOp::operatePrefix() +{ + writer->prepare(getHeader()); +} + +void ExchangeSenderSinkOp::operateSuffix() +{ + LOG_DEBUG(log, "finish write with {} rows", total_rows); +} + +OperatorStatus ExchangeSenderSinkOp::writeImpl(Block && block) +{ +#ifndef NDEBUG + FAIL_POINT_PAUSE(FailPoints::hang_in_execution); + if (writer->dagContext().isRootMPPTask()) + { + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_root_task_run); + } + else + { + FAIL_POINT_TRIGGER_EXCEPTION(FailPoints::exception_during_mpp_non_root_task_run); + } +#endif + + if (!block) + { + writer->flush(); + return OperatorStatus::FINISHED; + } + + total_rows += block.rows(); + writer->write(block); + return OperatorStatus::NEED_INPUT; +} + +OperatorStatus ExchangeSenderSinkOp::prepareImpl() +{ + return writer->isReadyForWrite() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; +} + +OperatorStatus ExchangeSenderSinkOp::awaitImpl() +{ + return writer->isReadyForWrite() ? OperatorStatus::NEED_INPUT : OperatorStatus::WAITING; +} + +} // namespace DB diff --git a/dbms/src/Operators/ExchangeSenderSinkOp.h b/dbms/src/Operators/ExchangeSenderSinkOp.h new file mode 100644 index 00000000000..037390f64e3 --- /dev/null +++ b/dbms/src/Operators/ExchangeSenderSinkOp.h @@ -0,0 +1,55 @@ +// Copyright 2023 PingCAP, Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace DB +{ +class ExchangeSenderSinkOp : public SinkOp +{ +public: + ExchangeSenderSinkOp( + PipelineExecutorStatus & exec_status_, + std::unique_ptr && writer, + const String & req_id) + : SinkOp(exec_status_) + , writer(std::move(writer)) + , log(Logger::get(req_id)) + { + } + + String getName() const override + { + return "ExchangeSenderSinkOp"; + } + + void operatePrefix() override; + void operateSuffix() override; + + OperatorStatus writeImpl(Block && block) override; + + OperatorStatus prepareImpl() override; + + OperatorStatus awaitImpl() override; + +private: + std::unique_ptr writer; + const LoggerPtr log; + size_t total_rows = 0; +}; +} // namespace DB diff --git a/dbms/src/Operators/Operator.cpp b/dbms/src/Operators/Operator.cpp index 88426df7032..dfc8237bdd2 100644 --- a/dbms/src/Operators/Operator.cpp +++ b/dbms/src/Operators/Operator.cpp @@ -108,7 +108,7 @@ OperatorStatus SinkOp::write(Block && block) // TODO collect operator profile info here. auto op_status = writeImpl(std::move(block)); #ifndef NDEBUG - assertOperatorStatus(op_status, {OperatorStatus::NEED_INPUT}); + assertOperatorStatus(op_status, {OperatorStatus::FINISHED, OperatorStatus::NEED_INPUT}); #endif return op_status; } diff --git a/dbms/src/Operators/Operator.h b/dbms/src/Operators/Operator.h index cba082151ea..18181cc12a7 100644 --- a/dbms/src/Operators/Operator.h +++ b/dbms/src/Operators/Operator.h @@ -22,13 +22,15 @@ namespace DB { /** * All interfaces of the operator may return the following state. - * - finish status and waiting status can be returned in all method of operator. + * - finish status will only be returned by sink op, because only sink can tell if the pipeline has actually finished. + * - cancel status and waiting status can be returned in all method of operator. * - operator may return a different running status depending on the method. */ enum class OperatorStatus { /// finish status FINISHED, + /// cancel status CANCELLED, /// waiting status WAITING, @@ -56,6 +58,10 @@ class Operator OperatorStatus await(); virtual OperatorStatus awaitImpl() { throw Exception("Unsupport"); } + // These two methods are used to set state, log and etc, and should not perform calculation logic. + virtual void operatePrefix() {} + virtual void operateSuffix() {} + virtual String getName() const = 0; /** Get data structure of the operator in a form of "header" block (it is also called "sample block"). diff --git a/dbms/src/Operators/OperatorHelper.cpp b/dbms/src/Operators/OperatorHelper.cpp index 8ed5a4e6dde..8dac56f9837 100644 --- a/dbms/src/Operators/OperatorHelper.cpp +++ b/dbms/src/Operators/OperatorHelper.cpp @@ -25,8 +25,7 @@ void assertOperatorStatus( { switch (status) { - // finish status and waiting status can be returned in all method of operator. - case OperatorStatus::FINISHED: + // cancel status and waiting status can be returned in all method of operator. case OperatorStatus::CANCELLED: case OperatorStatus::WAITING: return; diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.cpp b/dbms/src/TestUtils/MPPTaskTestUtils.cpp index 427ac99cd5e..38d0a0e0bbe 100644 --- a/dbms/src/TestUtils/MPPTaskTestUtils.cpp +++ b/dbms/src/TestUtils/MPPTaskTestUtils.cpp @@ -62,7 +62,7 @@ void MPPTaskTestUtils::startServers(size_t server_num_) { MockComputeServerManager::instance().addServer(MockServerAddrGenerator::instance().nextAddr()); // Currently, we simply add a context and don't care about destruct it. - TiFlashTestEnv::addGlobalContext(); + TiFlashTestEnv::addGlobalContext(context.context.getSettings()); TiFlashTestEnv::getGlobalContext(i + test_meta.context_idx).setMPPTest(); } diff --git a/dbms/src/TestUtils/MPPTaskTestUtils.h b/dbms/src/TestUtils/MPPTaskTestUtils.h index 75330ed0c6d..a33b5528714 100644 --- a/dbms/src/TestUtils/MPPTaskTestUtils.h +++ b/dbms/src/TestUtils/MPPTaskTestUtils.h @@ -70,9 +70,9 @@ class MPPTaskTestUtils : public ExecutorTest void TearDown() override; - static void startServers(); + void startServers(); - static void startServers(size_t server_num_); + void startServers(size_t server_num_); static size_t serverNum(); // run mpp tasks which are ready to cancel, the return value is the start_ts of query. @@ -94,7 +94,7 @@ class MPPTaskTestUtils : public ExecutorTest static MPPTestMeta test_meta; }; -#define ASSERT_MPPTASK_EQUAL(tasks, properties, expect_cols) \ +#define ASSERT_MPPTASK_EQUAL(tasks, properties, expected_cols) \ do \ { \ TiFlashTestEnv::getGlobalContext().setMPPTest(); \ @@ -123,14 +123,15 @@ class MPPTaskTestUtils : public ExecutorTest TiFlashTestEnv::getGlobalContext(i).setMPPTest(); \ auto tasks = (builder).buildMPPTasks(context, properties); \ size_t task_size = tasks.size(); \ + ASSERT_EQ(task_size, (expected_strings).size()); \ for (size_t i = 0; i < task_size; ++i) \ { \ ASSERT_DAGREQUEST_EQAUL((expected_strings)[i], tasks[i].dag_request); \ } \ ASSERT_MPPTASK_EQUAL_WITH_SERVER_NUM( \ - builder, \ - properties, \ - expect_cols); \ + (builder), \ + (properties), \ + (expected_cols)); \ } while (0) } // namespace DB::tests diff --git a/dbms/src/TestUtils/TiFlashTestEnv.cpp b/dbms/src/TestUtils/TiFlashTestEnv.cpp index cb6bf3f8ec8..d8971b0f2a4 100644 --- a/dbms/src/TestUtils/TiFlashTestEnv.cpp +++ b/dbms/src/TestUtils/TiFlashTestEnv.cpp @@ -87,10 +87,10 @@ void TiFlashTestEnv::tryRemovePath(const std::string & path, bool recreate) void TiFlashTestEnv::initializeGlobalContext(Strings testdata_path, PageStorageRunMode ps_run_mode, uint64_t bg_thread_count) { - addGlobalContext(testdata_path, ps_run_mode, bg_thread_count); + addGlobalContext(DB::Settings(), testdata_path, ps_run_mode, bg_thread_count); } -void TiFlashTestEnv::addGlobalContext(Strings testdata_path, PageStorageRunMode ps_run_mode, uint64_t bg_thread_count) +void TiFlashTestEnv::addGlobalContext(const DB::Settings & settings_, Strings testdata_path, PageStorageRunMode ps_run_mode, uint64_t bg_thread_count) { // set itself as global context auto global_context = std::make_shared(DB::Context::createGlobal()); @@ -104,6 +104,7 @@ void TiFlashTestEnv::addGlobalContext(Strings testdata_path, PageStorageRunMode global_context->initializeFileProvider(key_manager, false); // initialize background & blockable background thread pool + global_context->setSettings(settings_); Settings & settings = global_context->getSettingsRef(); global_context->initializeBackgroundPool(bg_thread_count == 0 ? settings.background_pool_size.get() : bg_thread_count); global_context->initializeBlockableBackgroundPool(bg_thread_count == 0 ? settings.background_pool_size.get() : bg_thread_count); diff --git a/dbms/src/TestUtils/TiFlashTestEnv.h b/dbms/src/TestUtils/TiFlashTestEnv.h index 1de5c3b9467..a9b1a4880ce 100644 --- a/dbms/src/TestUtils/TiFlashTestEnv.h +++ b/dbms/src/TestUtils/TiFlashTestEnv.h @@ -78,7 +78,7 @@ class TiFlashTestEnv static Context getContext(const DB::Settings & settings = DB::Settings(), Strings testdata_path = {}); static void initializeGlobalContext(Strings testdata_path = {}, PageStorageRunMode ps_run_mode = PageStorageRunMode::ONLY_V3, uint64_t bg_thread_count = 2); - static void addGlobalContext(Strings testdata_path = {}, PageStorageRunMode ps_run_mode = PageStorageRunMode::ONLY_V3, uint64_t bg_thread_count = 2); + static void addGlobalContext(const DB::Settings & settings_ = DB::Settings(), Strings testdata_path = {}, PageStorageRunMode ps_run_mode = PageStorageRunMode::ONLY_V3, uint64_t bg_thread_count = 2); static Context & getGlobalContext() { return *global_contexts[0]; } static Context & getGlobalContext(int idx) { return *global_contexts[idx]; } static int globalContextSize() { return global_contexts.size(); } diff --git a/dbms/src/TestUtils/mockExecutor.cpp b/dbms/src/TestUtils/mockExecutor.cpp index d1be7e1c17f..b40a7450e18 100644 --- a/dbms/src/TestUtils/mockExecutor.cpp +++ b/dbms/src/TestUtils/mockExecutor.cpp @@ -193,7 +193,7 @@ DAGRequestBuilder & DAGRequestBuilder::buildExchangeReceiver(const String & exch schema.push_back({exchange_name + "." + column.first, info}); } - root = mock::compileExchangeReceiver(getExecutorIndex(), schema, fine_grained_shuffle_stream_count); + root = mock::compileExchangeReceiver(getExecutorIndex(), schema, fine_grained_shuffle_stream_count, std::static_pointer_cast(root)); return *this; }