Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple user_kernel and device_tag #8529

Merged
merged 10 commits into from
Jul 1, 2022
1 change: 0 additions & 1 deletion oneflow/core/framework/infer_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class InferContext {
virtual int32_t output_size(const std::string& arg_name) const = 0;
virtual const std::string& op_name() const = 0;
virtual const std::string& op_type_name() const = 0;
virtual const std::string& device_tag() const = 0;
virtual const std::string& op_loc() const = 0;

template<typename T>
Expand Down
3 changes: 0 additions & 3 deletions oneflow/core/framework/op_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ class UserOpExprInferContext : public user_op::InferContext {
const std::function<TensorMeta*(int32_t)>& TensorMeta4OutputIndex)
: user_op_expr_(user_op_expr),
composed_attrs_(attrs, user_op_expr->base_attrs()),
device_tag_(device_tag),
tensor_meta4input_index_(TensorMeta4InputIndex),
tensor_meta4output_index_(TensorMeta4OutputIndex) {
loc_ = DispatchFrame::get_str();
Expand Down Expand Up @@ -302,7 +301,6 @@ class UserOpExprInferContext : public user_op::InferContext {
}
const std::string& op_name() const override { return user_op_expr_->op_name(); }
const std::string& op_type_name() const override { return user_op_expr_->op_type_name(); }
const std::string& device_tag() const override { return device_tag_; }
const std::string& op_loc() const override { return loc_; }

private:
Expand All @@ -312,7 +310,6 @@ class UserOpExprInferContext : public user_op::InferContext {
}
const UserOpExpr* user_op_expr_;
const ComposedAttrMap composed_attrs_;
const std::string& device_tag_;
const std::function<const TensorMeta*(int32_t)>& tensor_meta4input_index_;
const std::function<TensorMeta*(int32_t)>& tensor_meta4output_index_;
std::string loc_;
Expand Down
4 changes: 0 additions & 4 deletions oneflow/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class KernelInitContext {
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& device_tag() const { return user_op_conf().op_conf().device_tag(); }
const OperatorConf& op_conf() const { return user_op_conf().op_conf(); }

template<typename T>
Expand Down Expand Up @@ -133,7 +132,6 @@ class KernelCacheContext {
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& device_tag() const { return user_op_conf().op_conf().device_tag(); }
const OperatorConf& op_conf() const { return user_op_conf().op_conf(); }

template<typename T>
Expand Down Expand Up @@ -188,7 +186,6 @@ class KernelInferContext {
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& device_tag() const { return user_op_conf().op_conf().device_tag(); }

template<typename T>
const T& Attr(const std::string& attr_name) const {
Expand Down Expand Up @@ -249,7 +246,6 @@ class KernelComputeContext {
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& device_tag() const { return user_op_conf().op_conf().device_tag(); }

template<typename T>
const T& Attr(const std::string& attr_name) const {
Expand Down
1 change: 0 additions & 1 deletion oneflow/core/framework/user_op_kernel_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ class KernelRegContext {
virtual ~KernelRegContext() = default;

virtual DeviceType device_type() const = 0;
virtual const std::string& device_tag() const = 0;
virtual const ParallelContext& parallel_ctx() const = 0;
virtual const TensorDesc* TensorDesc4ArgNameAndIndex(const std::string&, int32_t) const = 0;

Expand Down
8 changes: 2 additions & 6 deletions oneflow/core/kernel/user_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ class UserKernelBaseContext {
};
InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().input(), &inputs_);
InitInOrOut(kernel_conf.op_attribute().op_conf().user_conf().output(), &outputs_);
device_tag_ = kernel_conf.op_attribute().op_conf().device_tag();
device_type_ = CHECK_JUST(DeviceType4DeviceTag(device_tag_));
device_type_ =
CHECK_JUST(DeviceType4DeviceTag(kernel_conf.op_attribute().op_conf().device_tag()));
parallel_ctx_ = kernel_conf.parallel_ctx();
for (const auto& pair : kernel_conf.user_conf().bn_in_op2blob_desc()) {
arg2bn_and_tensor_desc_.emplace(
Expand All @@ -89,7 +89,6 @@ class UserKernelBaseContext {
~UserKernelBaseContext() = default;

DeviceType device_type() const { return device_type_; }
const std::string& device_tag() const { return device_tag_; }
const ParallelContext& parallel_ctx() const { return parallel_ctx_; }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const {
Expand All @@ -108,7 +107,6 @@ class UserKernelBaseContext {
ArgVec inputs_;
ArgVec outputs_;
DeviceType device_type_;
std::string device_tag_;
ParallelContext parallel_ctx_;
};

Expand Down Expand Up @@ -357,7 +355,6 @@ class UserKernelOpInferContext : public user_op::InferContext {
}
const std::string& op_name() const override { return user_op_conf().op_name(); }
const std::string& op_type_name() const override { return user_op_conf().op_type_name(); }
const std::string& device_tag() const override { return user_op_conf().op_conf().device_tag(); }
const std::string& op_loc() const override { return user_op_conf_.op_conf().loc(); }

private:
Expand Down Expand Up @@ -581,7 +578,6 @@ class UserKernelRegContext final : public user_op::KernelRegContext {
~UserKernelRegContext() = default;

DeviceType device_type() const override { return base_ctx_.device_type(); }
const std::string& device_tag() const override { return base_ctx_.device_tag(); }
const ParallelContext& parallel_ctx() const override { return base_ctx_.parallel_ctx(); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
Expand Down
8 changes: 3 additions & 5 deletions oneflow/core/operator/user_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ class UserOpKernelRegContext final : public user_op::KernelRegContext {
const auto& op_conf = user_op->op_conf();
CHECK(op_conf.has_user_conf());

device_tag_ = op_conf.device_tag();
device_type_ = CHECK_JUST(DeviceType4DeviceTag(device_tag_));
device_type_ = CHECK_JUST(DeviceType4DeviceTag(op_conf.device_tag()));
parallel_ctx_ = parallel_ctx;

auto InitInOrOut = [&](const PbMap<std::string, UserOpConf::ListString>& arg_map,
Expand Down Expand Up @@ -97,7 +96,7 @@ class UserOpKernelRegContext final : public user_op::KernelRegContext {
~UserOpKernelRegContext() = default;

DeviceType device_type() const override { return device_type_; }
const std::string& device_tag() const override { return device_tag_; }

const ParallelContext& parallel_ctx() const override { return *parallel_ctx_; }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
Expand All @@ -120,7 +119,6 @@ class UserOpKernelRegContext final : public user_op::KernelRegContext {
ArgVec inputs_;
ArgVec outputs_;
DeviceType device_type_;
std::string device_tag_;
const ParallelContext* parallel_ctx_;
HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
};
Expand Down Expand Up @@ -270,7 +268,7 @@ class UserOpInferContext final : public user_op::InferContext {
}
const std::string& op_name() const override { return user_op_conf().op_name(); }
const std::string& op_type_name() const override { return user_op_conf().op_type_name(); }
const std::string& device_tag() const override { return user_op_conf().op_conf().device_tag(); }

const std::string& op_loc() const override { return op_->op_loc(); }

private:
Expand Down
5 changes: 2 additions & 3 deletions oneflow/user/kernels/arg_where_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,8 @@ struct SwitchUtil {
#undef SWITCH_ENTRY
};

template<DeviceType device_type>
size_t InferTempStorageBytesSize(user_op::InferContext* ctx) {
const std::string& device_tag = ctx->device_tag();
DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(device_tag));
const Shape& input_shape = ctx->InputShape("input", 0);
if (input_shape.NumAxes() == 0) { return 0; }
const DataType& input_dtype = ctx->InputDType("input", 0);
Expand All @@ -92,7 +91,7 @@ size_t InferTempStorageBytesSize(user_op::InferContext* ctx) {
&& (user_op::HobDataType("input", 0) == GetDataType<itype>::value) \
&& (user_op::HobDataType("output", 0) == GetDataType<otype>::value) \
&& (user_op::HobDataType("output_size", 0) == GetDataType<otype>::value)) \
.SetInferTmpSizeFn(InferTempStorageBytesSize);
.SetInferTmpSizeFn(InferTempStorageBytesSize<device>);

#define REGISTER_ARG_WHERE_KERNEL_WITH_DTYPE_PAIR(device, itype_pair, otype_pair) \
REGISTER_ARG_WHERE_KERNEL(device, OF_PP_PAIR_FIRST(itype_pair), OF_PP_PAIR_FIRST(otype_pair))
Expand Down
33 changes: 12 additions & 21 deletions oneflow/user/kernels/stateful_opkernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,24 +124,20 @@ class ZeroCopyBaseContextHelper {

class UserKernelBaseContextHelper final : public ZeroCopyBaseContextHelper {
public:
UserKernelBaseContextHelper(const std::string& device_tag,
UserKernelBaseContextHelper(DeviceType device_type,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple),
device_tag_(device_tag),
device_type_(CHECK_JUST(DeviceType4DeviceTag(device_tag_))) {}
: ZeroCopyBaseContextHelper(input_arg_tuple, output_arg_tuple), device_type_(device_type) {}

~UserKernelBaseContextHelper() = default;

DeviceType device_type() const { return device_type_; }
const std::string& device_tag() const { return device_tag_; }
const JobDesc& job_desc() const {
UNIMPLEMENTED();
return *(const JobDesc*)nullptr;
}

private:
const std::string device_tag_;
const DeviceType device_type_;
};

Expand Down Expand Up @@ -275,7 +271,6 @@ class UserOpInferContextHelper final {
}
const std::string& op_name() const { return user_op_conf().op_name(); }
const std::string& op_type_name() const { return user_op_conf().op_type_name(); }
const std::string& device_tag() const { return user_op_conf().op_conf().device_tag(); }
const std::string& op_loc() const { return user_op_conf_->op_conf().loc(); }

const user_op::UserOpConfWrapper& user_op_conf() const { return *user_op_conf_; }
Expand Down Expand Up @@ -392,7 +387,6 @@ class UserOpInferContext : public user_op::InferContext {
}
const std::string& op_name() const override { return helper_->op_name(); }
const std::string& op_type_name() const override { return helper_->op_type_name(); }
const std::string& device_tag() const override { return helper_->device_tag(); }
const std::string& op_loc() const override { return helper_->op_loc(); }

private:
Expand All @@ -407,12 +401,12 @@ class UserOpInferContext : public user_op::InferContext {

class UserKernelComputeContextHelper final {
public:
UserKernelComputeContextHelper(const std::string& device_tag,
UserKernelComputeContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_tag, input_arg_tuple, output_arg_tuple) {}
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}

~UserKernelComputeContextHelper() = default;

Expand Down Expand Up @@ -493,16 +487,14 @@ class UserKernelComputeContext final : public user_op::KernelComputeContext {

class UserKernelRegContextHelper final {
public:
UserKernelRegContextHelper(const std::string& device_tag,
const user_op::UserOpConfWrapper* user_op_conf,
UserKernelRegContextHelper(DeviceType device_type, const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_tag, input_arg_tuple, output_arg_tuple) {}
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}
~UserKernelRegContextHelper() = default;

DeviceType device_type() const { return base_ctx_helper_.device_type(); }
const std::string& device_tag() const { return base_ctx_helper_.device_tag(); }
const ParallelContext& parallel_ctx(eager::CallContext* call_ctx) const {
return base_ctx_helper_.parallel_ctx(call_ctx);
}
Expand Down Expand Up @@ -533,7 +525,6 @@ class UserKernelRegContext final : public user_op::KernelRegContext {
~UserKernelRegContext() = default;

DeviceType device_type() const override { return helper_->device_type(); }
const std::string& device_tag() const override { return helper_->device_tag(); }
const ParallelContext& parallel_ctx() const override { return helper_->parallel_ctx(call_ctx_); }
const user_op::TensorDesc* TensorDesc4ArgNameAndIndex(const std::string& arg_name,
int32_t index) const override {
Expand All @@ -558,12 +549,12 @@ class UserKernelRegContext final : public user_op::KernelRegContext {

class UserKernelInitAndCacheContextHelper final {
public:
UserKernelInitAndCacheContextHelper(const std::string& device_tag,
UserKernelInitAndCacheContextHelper(DeviceType device_type,
const user_op::UserOpConfWrapper* user_op_conf,
const std::shared_ptr<const ArgTuple>& input_arg_tuple,
const std::shared_ptr<const ArgTuple>& output_arg_tuple)
: user_op_conf_(user_op_conf),
base_ctx_helper_(device_tag, input_arg_tuple, output_arg_tuple) {}
base_ctx_helper_(device_type, input_arg_tuple, output_arg_tuple) {}

~UserKernelInitAndCacheContextHelper() = default;

Expand Down Expand Up @@ -751,18 +742,18 @@ Maybe<void> InitTensorTupleIndexes4Bns(const std::shared_ptr<const OperatorConf>
opkernel->output_arg_tuple_ = output_arg_tuple;
opkernel->need_check_mem_case_ = true;

const std::string& device_tag = op_conf->device_tag();
const DeviceType device_type = CHECK_JUST(DeviceType4DeviceTag(op_conf->device_tag()));
const user_op::UserOpConfWrapper* user_op_conf = opkernel->user_op_conf_.get();
opkernel->op_infer_ctx_helper_.reset(
new UserOpInferContextHelper(user_op_conf, input_arg_tuple, output_arg_tuple));

opkernel->init_and_cache_ctx_helper_.reset(new UserKernelInitAndCacheContextHelper(
opkernel->op_conf_->device_tag(), opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
device_type, opkernel->user_op_conf_.get(), opkernel->input_arg_tuple_,
opkernel->output_arg_tuple_));
opkernel->compute_ctx_helper_.reset(new UserKernelComputeContextHelper(
device_tag, user_op_conf, input_arg_tuple, output_arg_tuple));
device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
opkernel->reg_ctx_helper_.reset(
new UserKernelRegContextHelper(device_tag, user_op_conf, input_arg_tuple, output_arg_tuple));
new UserKernelRegContextHelper(device_type, user_op_conf, input_arg_tuple, output_arg_tuple));
const auto* op_reg_val =
user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(user_op_conf->op_type_name());
CHECK_NOTNULL_OR_RETURN(op_reg_val);
Expand Down