Skip to content

Commit

Permalink
fix tmp_buffer in stateful local kernel (#4757)
Browse files Browse the repository at this point in the history
* fix tmp_buffer in stateful local kernel

Signed-off-by: daquexian <daquexian566@gmail.com>

* move tmp_buffer_view_ != nullptr first

Signed-off-by: daquexian <daquexian566@gmail.com>

* Move tmp_buffer check to bottom

* add specific tmp_buffer not null check

Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
  • Loading branch information
3 people authored Apr 28, 2021
1 parent de87e6f commit fed4e81
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
37 changes: 27 additions & 10 deletions oneflow/user/kernels/stateful_local_opkernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ void InitArgName2BnIndex2TensorTupleIndex(

// Two-argument overload for callers that have no temporary buffer.
// Delegates to the three-argument constructor, passing nullptr as tmp_buffer.
// NOTE(review): that constructor only builds tmp_buffer_view_ when tmp_buffer
// is non-null, so a context built this way presumably cannot serve a
// ("tmp_buffer", 0) lookup -- confirm against Tensor4ArgNameAndIndex.
ZeroCopyBaseContext::ZeroCopyBaseContext(const ArgVec* indexed_input_pairs,
const ArgVec* indexed_output_pairs)
: ZeroCopyBaseContext(indexed_input_pairs, indexed_output_pairs, nullptr) {}

ZeroCopyBaseContext::ZeroCopyBaseContext(const ArgVec* indexed_input_pairs,
const ArgVec* indexed_output_pairs,
vm::EagerBlobObject* tmp_buffer)
: indexed_input_pairs_(indexed_input_pairs), indexed_output_pairs_(indexed_output_pairs) {
InitArgName2BnIndex2TensorTupleIndex(indexed_input_pairs,
&arg_name2bn_index2input_tensor_tuple_index_);
Expand All @@ -78,6 +83,9 @@ ZeroCopyBaseContext::ZeroCopyBaseContext(const ArgVec* indexed_input_pairs,
output_tensor_desc_views_.push_back(std::make_unique<EagerBlobObjectTensorDescView>(
[this, i]() -> vm::EagerBlobObject* { return output_tensors_->at(i).get(); }));
}
if (tmp_buffer != nullptr) {
tmp_buffer_view_.reset(new EagerBlobObjectTensorView([tmp_buffer]() { return tmp_buffer; }));
}
}

void ZeroCopyBaseContext::Update(EagerBlobObjectList inputs, EagerBlobObjectList outputs) {
Expand All @@ -101,16 +109,24 @@ user_op::Tensor* ZeroCopyBaseContext::Tensor4ArgNameAndIndex(const std::string&
if (i >= 0) { return input_tensor_views_.at(i).get(); }
i = TryGetTensorTupleIndex(arg_name2bn_index2output_tensor_tuple_index_, arg_name, index);
if (i >= 0) { return output_tensor_views_.at(i).get(); }
if (arg_name == "tmp_buffer" && index == 0) { return CHECK_NOTNULL(tmp_buffer_view_.get()); }
LOG(FATAL) << "Arg (" << arg_name << "," << index << ") is not found";
return nullptr;
}

LocalUserKernelBaseContext::LocalUserKernelBaseContext(const std::string& device_tag,
const ArgVec* indexed_input_pairs,
const ArgVec* indexed_output_pairs)
: ZeroCopyBaseContext(indexed_input_pairs, indexed_output_pairs),
: LocalUserKernelBaseContext(device_tag, indexed_input_pairs, indexed_output_pairs, nullptr) {}

LocalUserKernelBaseContext::LocalUserKernelBaseContext(const std::string& device_tag,
const ArgVec* indexed_input_pairs,
const ArgVec* indexed_output_pairs,
vm::EagerBlobObject* tmp_buffer)
: ZeroCopyBaseContext(indexed_input_pairs, indexed_output_pairs, tmp_buffer),
device_tag_(device_tag),
device_type_(CHECK_JUST(DeviceType4DeviceTag(device_tag_))) {}
device_type_(CHECK_JUST(DeviceType4DeviceTag(device_tag_))),
tmp_buffer_(tmp_buffer) {}

class LocalUserKernelRegContext final : public user_op::KernelRegContext {
public:
Expand Down Expand Up @@ -228,10 +244,10 @@ void LocalUserOpInferContext::Update(EagerBlobObjectList inputs, EagerBlobObject
LocalUserKernelComputeContext::LocalUserKernelComputeContext(
DeviceCtx* device_ctx, const std::string& device_tag,
const user_op::UserOpConfWrapper* user_op_conf, const ArgVec* index_input_pairs,
const ArgVec* indexed_output_pairs)
const ArgVec* indexed_output_pairs, vm::EagerBlobObject* tmp_buffer)
: user_op_conf_(user_op_conf),
device_ctx_(device_ctx),
base_ctx_(device_tag, index_input_pairs, indexed_output_pairs) {}
base_ctx_(device_tag, index_input_pairs, indexed_output_pairs, tmp_buffer) {}

void LocalUserKernelComputeContext::Update(EagerBlobObjectList inputs, EagerBlobObjectList outputs,
DeviceCtx* device_ctx) {
Expand Down Expand Up @@ -318,12 +334,16 @@ Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>
opkernel->indexed_output_pairs_ = indexed_output_pairs;
opkernel->need_check_mem_case_ = true;

opkernel->tmp_blob_object_.reset(
new vm::EagerBlobObject(opkernel->mem_case_, std::make_shared<Shape>(), DataType::kChar,
std::make_shared<vm::TensorBuffer>()));

const std::string& device_tag = op_conf->device_tag();
opkernel->op_infer_ctx_.reset(new LocalUserOpInferContext(
opkernel->user_op_conf_.get(), indexed_input_pairs.get(), indexed_output_pairs.get()));
opkernel->compute_ctx_.reset(
new LocalUserKernelComputeContext(nullptr, device_tag, opkernel->user_op_conf_.get(),
indexed_input_pairs.get(), indexed_output_pairs.get()));
opkernel->compute_ctx_.reset(new LocalUserKernelComputeContext(
nullptr, device_tag, opkernel->user_op_conf_.get(), indexed_input_pairs.get(),
indexed_output_pairs.get(), opkernel->mut_temp_blob_object()));
opkernel->create_ctx_.reset(new LocalUserKernelCreateContext(opkernel->user_op_conf_.get()));
opkernel->reg_ctx_.reset(new LocalUserKernelRegContext(device_tag, opkernel->user_op_conf_.get(),
indexed_input_pairs.get(),
Expand All @@ -343,9 +363,6 @@ Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>
&opkernel->input_tuple_indexes4const_ibns_, &opkernel->input_tuple_indexes4mut_ibns_,
&opkernel->output_tuple_indexes4mut_obns_, &opkernel->output_tuple_indexes4mut2_obns_));

opkernel->tmp_blob_object_.reset(
new vm::EagerBlobObject(opkernel->mem_case_, std::make_shared<Shape>(), DataType::kChar,
std::make_shared<vm::TensorBuffer>()));
opkernel->infer_local_dep_object_ = std::make_shared<VmLocalDepObject>(parallel_desc);
opkernel->compute_local_dep_object_ = std::make_shared<VmLocalDepObject>(parallel_desc);
return opkernel;
Expand Down
9 changes: 8 additions & 1 deletion oneflow/user/kernels/stateful_local_opkernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class EagerBlobObjectTensorDescView final : public user_op::TensorDesc {
class ZeroCopyBaseContext {
public:
ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs);
ZeroCopyBaseContext(const ArgVec* indexed_input_pairs, const ArgVec* indexed_output_pairs,
vm::EagerBlobObject* tmp_buffer);

user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name, int32_t index) const;

Expand All @@ -118,6 +120,7 @@ class ZeroCopyBaseContext {
std::vector<std::unique_ptr<EagerBlobObjectTensorView>> output_tensor_views_;
std::vector<std::unique_ptr<EagerBlobObjectTensorDescView>> input_tensor_desc_views_;
std::vector<std::unique_ptr<EagerBlobObjectTensorDescView>> output_tensor_desc_views_;
std::unique_ptr<EagerBlobObjectTensorView> tmp_buffer_view_;
EagerBlobObjectList input_tensors_;
EagerBlobObjectList output_tensors_;
};
Expand All @@ -126,6 +129,8 @@ class LocalUserKernelBaseContext : public ZeroCopyBaseContext {
public:
LocalUserKernelBaseContext(const std::string& device_tag, const ArgVec* indexed_input_pairs,
const ArgVec* indexed_output_pairs);
LocalUserKernelBaseContext(const std::string& device_tag, const ArgVec* indexed_input_pairs,
const ArgVec* indexed_output_pairs, vm::EagerBlobObject* tmp_buffer);
~LocalUserKernelBaseContext() = default;

DeviceType device_type() const { return device_type_; }
Expand All @@ -135,6 +140,7 @@ class LocalUserKernelBaseContext : public ZeroCopyBaseContext {
private:
const std::string device_tag_;
const DeviceType device_type_;
vm::EagerBlobObject* tmp_buffer_;
};

class LocalUserOpInferContext : public user_op::InferContext {
Expand Down Expand Up @@ -196,7 +202,8 @@ class LocalUserKernelComputeContext final : public user_op::KernelComputeContext
explicit LocalUserKernelComputeContext(DeviceCtx* device_ctx, const std::string& device_tag,
const user_op::UserOpConfWrapper* user_op_conf,
const ArgVec* index_input_pairs,
const ArgVec* indexed_output_pairs);
const ArgVec* indexed_output_pairs,
vm::EagerBlobObject* tmp_buffer);
~LocalUserKernelComputeContext() = default;

const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
Expand Down

0 comments on commit fed4e81

Please sign in to comment.