From fed4e8101cce11a748a07567df453b79d3cbab0c Mon Sep 17 00:00:00 2001 From: daquexian Date: Wed, 28 Apr 2021 15:33:51 +0800 Subject: [PATCH] fix tmp_buffer in stateful local kernel (#4757) * fix tmp_buffer in stateful local kernel Signed-off-by: daquexian * move tmp_buffer_view_ != nullptr first Signed-off-by: daquexian * Move tmp_buffer check to bottom * add specific tmp_buffer not null check Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../user/kernels/stateful_local_opkernel.cpp | 37 ++++++++++++++----- .../user/kernels/stateful_local_opkernel.h | 9 ++++- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp index b2e9ded71b3..77d60ca808e 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_local_opkernel.cpp @@ -61,6 +61,11 @@ void InitArgName2BnIndex2TensorTupleIndex( ZeroCopyBaseContext::ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs) + : ZeroCopyBaseContext(indexed_input_pairs, indexed_output_pairs, nullptr) {} + +ZeroCopyBaseContext::ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, + const ArgVec* indexed_output_pairs, + vm::EagerBlobObject* tmp_buffer) : indexed_input_pairs_(indexed_input_pairs), indexed_output_pairs_(indexed_output_pairs) { InitArgName2BnIndex2TensorTupleIndex(indexed_input_pairs, &arg_name2bn_index2input_tensor_tuple_index_); @@ -78,6 +83,9 @@ ZeroCopyBaseContext::ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, output_tensor_desc_views_.push_back(std::make_unique( [this, i]() -> vm::EagerBlobObject* { return output_tensors_->at(i).get(); })); } + if (tmp_buffer != nullptr) { + tmp_buffer_view_.reset(new EagerBlobObjectTensorView([tmp_buffer]() { return tmp_buffer; })); + } } void ZeroCopyBaseContext::Update(EagerBlobObjectList inputs, EagerBlobObjectList outputs) { @@ -101,6 +109,7 @@ user_op::Tensor* ZeroCopyBaseContext::Tensor4ArgNameAndIndex(const std::string& if (i >= 0) { return input_tensor_views_.at(i).get(); } i = TryGetTensorTupleIndex(arg_name2bn_index2output_tensor_tuple_index_, arg_name, index); if (i >= 0) { return output_tensor_views_.at(i).get(); } + if (arg_name == "tmp_buffer" && index == 0) { return CHECK_NOTNULL(tmp_buffer_view_.get()); } LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found"; return nullptr; } @@ -108,9 +117,16 @@ user_op::Tensor* ZeroCopyBaseContext::Tensor4ArgNameAndIndex(const std::string& LocalUserKernelBaseContext::LocalUserKernelBaseContext(const std::string& device_tag, const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs) - : ZeroCopyBaseContext(indexed_input_pairs, indexed_output_pairs), + : LocalUserKernelBaseContext(device_tag, indexed_input_pairs, indexed_output_pairs, nullptr) {} + +LocalUserKernelBaseContext::LocalUserKernelBaseContext(const std::string& device_tag, + const ArgVec* indexed_input_pairs, + const ArgVec* indexed_output_pairs, + vm::EagerBlobObject* tmp_buffer) + : ZeroCopyBaseContext(indexed_input_pairs, indexed_output_pairs, tmp_buffer), device_tag_(device_tag), - device_type_(CHECK_JUST(DeviceType4DeviceTag(device_tag_))) {} + device_type_(CHECK_JUST(DeviceType4DeviceTag(device_tag_))), + tmp_buffer_(tmp_buffer) {} class LocalUserKernelRegContext final : public user_op::KernelRegContext { public: @@ -228,10 +244,10 @@ void LocalUserOpInferContext::Update(EagerBlobObjectList inputs, EagerBlobObject LocalUserKernelComputeContext::LocalUserKernelComputeContext( DeviceCtx* device_ctx, const std::string& device_tag, const user_op::UserOpConfWrapper* user_op_conf, const ArgVec* index_input_pairs, - const ArgVec* indexed_output_pairs) + const ArgVec* indexed_output_pairs, vm::EagerBlobObject* tmp_buffer) : user_op_conf_(user_op_conf), device_ctx_(device_ctx), - base_ctx_(device_tag, index_input_pairs, indexed_output_pairs) {} + base_ctx_(device_tag, index_input_pairs, indexed_output_pairs, tmp_buffer) {} void LocalUserKernelComputeContext::Update(EagerBlobObjectList inputs, EagerBlobObjectList outputs, DeviceCtx* device_ctx) { @@ -318,12 +334,16 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr opkernel->indexed_output_pairs_ = indexed_output_pairs; opkernel->need_check_mem_case_ = true; + opkernel->tmp_blob_object_.reset( + new vm::EagerBlobObject(opkernel->mem_case_, std::make_shared(), DataType::kChar, + std::make_shared())); + const std::string& device_tag = op_conf->device_tag(); opkernel->op_infer_ctx_.reset(new LocalUserOpInferContext( opkernel->user_op_conf_.get(), indexed_input_pairs.get(), indexed_output_pairs.get())); - opkernel->compute_ctx_.reset( - new LocalUserKernelComputeContext(nullptr, device_tag, opkernel->user_op_conf_.get(), - indexed_input_pairs.get(), indexed_output_pairs.get())); + opkernel->compute_ctx_.reset(new LocalUserKernelComputeContext( + nullptr, device_tag, opkernel->user_op_conf_.get(), indexed_input_pairs.get(), + indexed_output_pairs.get(), opkernel->mut_temp_blob_object())); opkernel->create_ctx_.reset(new LocalUserKernelCreateContext(opkernel->user_op_conf_.get())); opkernel->reg_ctx_.reset(new LocalUserKernelRegContext(device_tag, opkernel->user_op_conf_.get(), indexed_input_pairs.get(), @@ -343,9 +363,6 @@ Maybe InitTensorTupleIndexes4Bns(const std::shared_ptr &opkernel->input_tuple_indexes4const_ibns_, &opkernel->input_tuple_indexes4mut_ibns_, &opkernel->output_tuple_indexes4mut_obns_, &opkernel->output_tuple_indexes4mut2_obns_)); - opkernel->tmp_blob_object_.reset( - new vm::EagerBlobObject(opkernel->mem_case_, std::make_shared(), DataType::kChar, - std::make_shared())); opkernel->infer_local_dep_object_ = std::make_shared(parallel_desc); opkernel->compute_local_dep_object_ = std::make_shared(parallel_desc); return opkernel; diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h index db3f349d5cb..bb83715c3ce 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.h +++ b/oneflow/user/kernels/stateful_local_opkernel.h @@ -99,6 +99,8 @@ class EagerBlobObjectTensorDescView final : public user_op::TensorDesc { class ZeroCopyBaseContext { public: ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs); + ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs, + vm::EagerBlobObject* tmp_buffer); user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const; @@ -118,6 +120,7 @@ class ZeroCopyBaseContext { std::vector> output_tensor_views_; std::vector> input_tensor_desc_views_; std::vector> output_tensor_desc_views_; + std::unique_ptr tmp_buffer_view_; EagerBlobObjectList input_tensors_; EagerBlobObjectList output_tensors_; }; @@ -126,6 +129,8 @@ class LocalUserKernelBaseContext : public ZeroCopyBaseContext { public: LocalUserKernelBaseContext(const std::string& device_tag, const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs); + LocalUserKernelBaseContext(const std::string& device_tag, const ArgVec* indexed_input_pairs, + const ArgVec* indexed_output_pairs, vm::EagerBlobObject* tmp_buffer); ~LocalUserKernelBaseContext() = default; DeviceType device_type() const { return device_type_; } @@ -135,6 +140,7 @@ class LocalUserKernelBaseContext : public ZeroCopyBaseContext { private: const std::string device_tag_; const DeviceType device_type_; + vm::EagerBlobObject* tmp_buffer_; }; class LocalUserOpInferContext : public user_op::InferContext { @@ -196,7 +202,8 @@ class LocalUserKernelComputeContext final : public user_op::KernelComputeContext explicit LocalUserKernelComputeContext(DeviceCtx* device_ctx, const std::string& device_tag, const user_op::UserOpConfWrapper* user_op_conf, const ArgVec* index_input_pairs, - const ArgVec* indexed_output_pairs); + const ArgVec* indexed_output_pairs, + vm::EagerBlobObject* tmp_buffer); ~LocalUserKernelComputeContext() = default; const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,