From 5af7d1beabd9375937b03ce769383742bdf9ee7f Mon Sep 17 00:00:00 2001 From: ZZK <42901638+MARD1NO@users.noreply.github.com> Date: Thu, 29 Apr 2021 11:44:52 +0800 Subject: [PATCH 1/3] fix bug about pos_weight (#4768) Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- oneflow/python/ops/nn_ops.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/oneflow/python/ops/nn_ops.py b/oneflow/python/ops/nn_ops.py index fb0cb6a7d6e..a449235cfe3 100644 --- a/oneflow/python/ops/nn_ops.py +++ b/oneflow/python/ops/nn_ops.py @@ -3848,13 +3848,6 @@ def bce_with_logits_loss_job(input: tp.Numpy.Placeholder(shape=(2, 3)), reduction ) - assert pos_weight.shape[0] == input.shape[-1], ( - "The length of `pos_weight` must be equal to the number of classes. " - "Found the length of pos_weight {} vs classes {}".format( - pos_weight.shape[0], input.shape[-1] - ) - ) - if name is None: name = id_util.UniqueStr("BCEWithLogitsLoss") @@ -3863,6 +3856,12 @@ def bce_with_logits_loss_job(input: tp.Numpy.Placeholder(shape=(2, 3)), _neg_max_val = flow.math.negative(_max_val) if pos_weight: + assert pos_weight.shape[0] == input.shape[-1], ( + "The length of `pos_weight` must be equal to the number of classes. " + "Found the length of pos_weight {} vs classes {}".format( + pos_weight.shape[0], input.shape[-1] + ) + ) _log_weight = ((pos_weight - 1) * target) + 1 _loss = (1 - target) * input + _log_weight * ( flow.math.log( From c402963faf906cd199b9438e896728a95f2c985c Mon Sep 17 00:00:00 2001 From: Yurui Li <32978179+poohRui@users.noreply.github.com> Date: Thu, 29 Apr 2021 12:28:45 +0800 Subject: [PATCH 2/3] Create default log under log dir (#4766) * move default log under log dir * remove default env log from proto * refine Co-authored-by: Shenghang Tsai Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .gitignore | 1 - oneflow/core/job/env.proto | 1 - oneflow/core/job/env_global_objects_scope.cpp | 5 +++-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index a62edd04570..46c05ff0c17 100644 --- a/.gitignore +++ b/.gitignore @@ -7,7 +7,6 @@ /examples/**/oneflow /benchmark/**/oneflow log/ -default_physical_env_log/ *plan core.* *.pyc diff --git a/oneflow/core/job/env.proto b/oneflow/core/job/env.proto index 010686704da..09bd5fe367c 100644 --- a/oneflow/core/job/env.proto +++ b/oneflow/core/job/env.proto @@ -14,7 +14,6 @@ message CppLoggingConf { optional string log_dir = 1 [default = "./log"]; optional int32 logtostderr = 2 [default = 0]; optional int32 logbuflevel = 3 [default = -1]; - optional string default_physical_env_log_dir = 4 [default = "./default_physical_env_log"]; } message EnvProto { diff --git a/oneflow/core/job/env_global_objects_scope.cpp b/oneflow/core/job/env_global_objects_scope.cpp index 3b14b252278..7a0ff78a746 100644 --- a/oneflow/core/job/env_global_objects_scope.cpp +++ b/oneflow/core/job/env_global_objects_scope.cpp @@ -42,7 +42,7 @@ namespace { std::string LogDir(const std::string& log_dir) { char hostname[255]; CHECK_EQ(gethostname(hostname, sizeof(hostname)), 0); - std::string v = log_dir + "/" + std::string(hostname); + std::string v = JoinPath(log_dir, std::string(hostname)); return v; } @@ -50,7 +50,8 @@ void InitLogging(const CppLoggingConf& logging_conf, bool default_physical_env) if (!default_physical_env) { FLAGS_log_dir = LogDir(logging_conf.log_dir()); } else { - FLAGS_log_dir = LogDir(logging_conf.default_physical_env_log_dir()); + std::string default_env_log_path = JoinPath(logging_conf.log_dir(), "default_physical_env_log"); + FLAGS_log_dir = LogDir(default_env_log_path); } FLAGS_logtostderr = logging_conf.logtostderr(); FLAGS_logbuflevel = logging_conf.logbuflevel(); From 354652eef7f9f24fa761b970b0d82b213f32a9ee Mon Sep 17 00:00:00 2001 From: Houjiang Chen Date: Thu, 29 Apr 2021 13:49:42 +0800 Subject: [PATCH 3/3] Dev refactor attr value map (#4755) * Refactor * Draft * Refactor * Add MutableCfgAttrValueMap * implement AttrValueMap and ComposedAttrValueMap (#4767) * Attr value util (#4773) * implement AttrValueMap and ComposedAttrValueMap * AttrValueUtil::ToProtoAttrValue * Fix compile * Fix compilation * Rename AttrValueMap by AttrMap. Co-authored-by: lixinqi Co-authored-by: hjchen2 Co-authored-by: Li Xinqi Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- .../{attr_val_map.cpp => attr_map.cpp} | 15 ++- oneflow/api/python/framework/op_expr.cpp | 10 +- .../autograd/gradient_funcs/batch_gather.cpp | 6 +- .../core/autograd/gradient_funcs/bias_add.cpp | 4 +- .../gradient_funcs/broadcast_binary_ops.cpp | 4 +- oneflow/core/autograd/gradient_funcs/cast.cpp | 4 +- .../core/autograd/gradient_funcs/deconv.cpp | 4 +- .../core/autograd/gradient_funcs/default.cpp | 4 +- .../autograd/gradient_funcs/layer_norm.cpp | 6 +- .../autograd/gradient_funcs/normalization.cpp | 4 +- .../core/autograd/gradient_funcs/reshape.cpp | 2 +- .../autograd/gradient_funcs/split_like.cpp | 6 +- .../gradient_funcs/tensor_scalar_binary.cpp | 19 ++- oneflow/core/common/shape.cpp | 6 + oneflow/core/common/shape.h | 5 + .../local_call_opkernel_phy_instr_operand.h | 9 +- oneflow/core/framework/attr_map.cpp | 112 ++++++++++++++++++ oneflow/core/framework/attr_map.h | 100 ++++++++++++++++ .../core/framework/attr_value_accessor.cpp | 68 +++++++++++ oneflow/core/framework/attr_value_accessor.h | 14 +++ oneflow/core/framework/attr_value_map.cpp | 66 ----------- oneflow/core/framework/attr_value_map.h | 37 ------ .../core/framework/instructions_builder.cpp | 2 +- oneflow/core/framework/instructions_builder.h | 2 +- oneflow/core/framework/op_expr.cpp | 26 ++-- oneflow/core/framework/op_expr.h | 6 +- .../core/framework/op_expr_grad_function.h | 8 +- oneflow/core/framework/op_interpreter.h | 18 +-- .../eager_consistent_op_interpreter.cpp | 18 +-- .../eager_mirrored_op_interpreter.cpp | 24 ++-- .../op_interpreter/op_interpreter.cpp | 12 +- .../op_interpreter/op_interpreter_util.cpp | 8 +- .../op_interpreter/op_interpreter_util.h | 13 +- oneflow/core/framework/user_op_conf_trait.h | 13 +- oneflow/python/framework/op_expr_util.py | 2 +- .../user/kernels/stateful_local_opkernel.cpp | 8 +- .../user/kernels/stateful_local_opkernel.h | 4 +- 37 files changed, 436 insertions(+), 233 deletions(-) rename oneflow/api/python/framework/{attr_val_map.cpp => attr_map.cpp} (68%) create mode 100644 oneflow/core/framework/attr_map.cpp create mode 100644 oneflow/core/framework/attr_map.h delete mode 100644 oneflow/core/framework/attr_value_map.cpp delete mode 100644 oneflow/core/framework/attr_value_map.h diff --git a/oneflow/api/python/framework/attr_val_map.cpp b/oneflow/api/python/framework/attr_map.cpp similarity index 68% rename from oneflow/api/python/framework/attr_val_map.cpp rename to oneflow/api/python/framework/attr_map.cpp index 001dfca068d..a2214f2e46d 100644 --- a/oneflow/api/python/framework/attr_val_map.cpp +++ b/oneflow/api/python/framework/attr_map.cpp @@ -17,7 +17,7 @@ limitations under the License. #include #include "oneflow/api/python/of_api_registry.h" -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/user_op_attr.cfg.h" namespace py = pybind11; @@ -25,21 +25,20 @@ namespace py = pybind11; namespace oneflow { ONEFLOW_API_PYBIND11_MODULE("", m) { - py::class_>(m, "AttrValueMap") + py::class_>(m, "MutableCfgAttrMap") .def(py::init<>()) .def("__setitem__", - [](AttrValueMap* m, const std::string& attr_name, + [](MutableCfgAttrMap* m, const std::string& attr_name, const std::shared_ptr& attr_value) { m->SetAttr(attr_name, attr_value).GetOrThrow(); }) .def("__getitem__", - [](const AttrValueMap& m, const std::string& attr_name) { - m.GetAttr(attr_name).GetOrThrow(); - }) + [](const MutableCfgAttrMap& m, const std::string& attr_name) { return m.at(attr_name); }) .def( - "__iter__", [](const AttrValueMap& m) { return py::make_iterator(m.begin(), m.end()); }, + "__iter__", + [](const MutableCfgAttrMap& m) { return py::make_iterator(m.begin(), m.end()); }, py::keep_alive<0, 1>()) - .def("__len__", [](const AttrValueMap& m) { return m.size(); }); + .def("__len__", [](const MutableCfgAttrMap& m) { return m.size(); }); } } // namespace oneflow diff --git a/oneflow/api/python/framework/op_expr.cpp b/oneflow/api/python/framework/op_expr.cpp index 2b02965cab1..9c6e392a9bb 100644 --- a/oneflow/api/python/framework/op_expr.cpp +++ b/oneflow/api/python/framework/op_expr.cpp @@ -17,7 +17,7 @@ limitations under the License. #include #include "oneflow/api/python/of_api_registry.h" #include "oneflow/core/common/protobuf.h" -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/op_interpreter.h" #include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" @@ -32,7 +32,7 @@ namespace oneflow { namespace { Maybe Interpret(const one::OpExpr& op, const one::TensorTuple& inputs, - const AttrValueMap& attrs) { + const AttrMap& attrs) { CHECK_EQ_OR_RETURN(op.input_size(), inputs.size()) << "The operation requires " << op.input_size() << " inputs, but " << inputs.size() << " is given."; @@ -44,7 +44,7 @@ Maybe Interpret(const one::OpExpr& op, const one::TensorTuple& Maybe Interpret(const one::OpExpr& op, const std::vector>& inputs, - const AttrValueMap& attrs) { + const AttrMap& attrs) { one::TensorTuple input_list(inputs.size()); for (int i = 0; i < inputs.size(); ++i) { input_list[i] = inputs[i]; } return JUST(Interpret(op, input_list, attrs)); @@ -76,11 +76,11 @@ ONEFLOW_API_PYBIND11_MODULE("one", m) { .def_property_readonly("output_size", &one::OpExpr::output_size) .def("apply", [](const one::OpExpr& op_expr, const std::vector>& inputs, - const AttrValueMap& attrs) { + const MutableCfgAttrMap& attrs) { return Interpret(op_expr, inputs, attrs).GetPtrOrThrow(); }) .def("apply", [](const one::OpExpr& op_expr, const one::TensorTuple& inputs, - const AttrValueMap& attrs) { + const MutableCfgAttrMap& attrs) { return Interpret(op_expr, inputs, attrs).GetPtrOrThrow(); }); diff --git a/oneflow/core/autograd/gradient_funcs/batch_gather.cpp b/oneflow/core/autograd/gradient_funcs/batch_gather.cpp index ed7c7a353b5..bfba2bd0c6c 100644 --- a/oneflow/core/autograd/gradient_funcs/batch_gather.cpp +++ b/oneflow/core/autograd/gradient_funcs/batch_gather.cpp @@ -31,7 +31,7 @@ class BatchGather : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -49,7 +49,7 @@ Maybe BatchGather::Init(const OpExpr& op) { } Maybe BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& in_shape = inputs.at(0)->shape(); @@ -64,7 +64,7 @@ Maybe BatchGather::Apply(const BatchGatherInterpState* ctx, const TensorTu in_grads->resize(2); if (!ctx->requires_grad) { return Maybe::Ok(); } const auto& indices = ctx->SavedTensors().at(0); - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr("num_segments", ctx->num_segments)); in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*bw_unsorted_batch_segment_sum_op_, {out_grads.at(0), indices}, attrs)); diff --git a/oneflow/core/autograd/gradient_funcs/bias_add.cpp b/oneflow/core/autograd/gradient_funcs/bias_add.cpp index b318cf9a414..53969595e83 100644 --- a/oneflow/core/autograd/gradient_funcs/bias_add.cpp +++ b/oneflow/core/autograd/gradient_funcs/bias_add.cpp @@ -43,7 +43,7 @@ class BiasAdd : public OpExprGradFunction { } Maybe Capture(BiasAddInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override { + const TensorTuple& outputs, const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); ctx->input_requires_grad = inputs.at(0)->requires_grad(); ctx->bias_requires_grad = inputs.at(1)->requires_grad(); @@ -60,7 +60,7 @@ class BiasAdd : public OpExprGradFunction { for (int i = 0; i < num_axes; ++i) { if (i != axis_) { reduce_axes_vec.push_back(i); } } - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr>("axis", reduce_axes_vec)); in_grads->at(1) = JUST(OpInterpUtil::Dispatch(*backward_bias_op_, {out_grads.at(0)}, attrs)); diff --git a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp index b3aa09940e2..00802646cf1 100644 --- a/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp +++ b/oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp @@ -38,7 +38,7 @@ class ReduceSumLikeModule { const auto& in_shape = *(input->shape()); const auto& like_shape = *(like->shape()); TensorTuple inputs{input}; - AttrValueMap attrs; + MutableAttrMap attrs; std::shared_ptr op = identity_op_; if (in_shape != like_shape) { const Shape& left_extended_shape = @@ -77,7 +77,7 @@ class BroadcastBinaryGrad : public OpExprGradFunction { } Maybe Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrValueMap& attrs) const override { + const AttrMap& attrs) const override { CHECK_EQ_OR_RETURN(inputs.size(), 2); CHECK_EQ_OR_RETURN(outputs.size(), 1); ctx->SaveTensorForBackward(inputs.at(0)); diff --git a/oneflow/core/autograd/gradient_funcs/cast.cpp b/oneflow/core/autograd/gradient_funcs/cast.cpp index 63c1815b2cd..072a09a30de 100644 --- a/oneflow/core/autograd/gradient_funcs/cast.cpp +++ b/oneflow/core/autograd/gradient_funcs/cast.cpp @@ -35,7 +35,7 @@ class Cast : public OpExprGradFunction { } Maybe Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrValueMap& attrs) const override { + const AttrMap& attrs) const override { ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } @@ -44,7 +44,7 @@ class Cast : public OpExprGradFunction { TensorTuple* in_grads) const override { const auto& x = ctx->SavedTensors().at(0); in_grads->resize(1); - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr("dtype", x->dtype()->data_type())); in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*grad_op_, {out_grads.at(0)}, attrs)); return Maybe::Ok(); diff --git a/oneflow/core/autograd/gradient_funcs/deconv.cpp b/oneflow/core/autograd/gradient_funcs/deconv.cpp index 4bdc305df43..0afb818c021 100644 --- a/oneflow/core/autograd/gradient_funcs/deconv.cpp +++ b/oneflow/core/autograd/gradient_funcs/deconv.cpp @@ -32,7 +32,7 @@ class DeConvolutionNd : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(DeConvolutionNdInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DeConvolutionNdInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -72,7 +72,7 @@ Maybe DeConvolutionNd::Init(const OpExpr& op) { } Maybe DeConvolutionNd::Capture(DeConvolutionNdInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { ctx->activation_requires_grad = inputs.at(0)->requires_grad(); ctx->weight_requires_grad = inputs.at(1)->requires_grad(); if (ctx->activation_requires_grad) { diff --git a/oneflow/core/autograd/gradient_funcs/default.cpp b/oneflow/core/autograd/gradient_funcs/default.cpp index 0b2e3d5ec5d..ef5cd849782 100644 --- a/oneflow/core/autograd/gradient_funcs/default.cpp +++ b/oneflow/core/autograd/gradient_funcs/default.cpp @@ -44,7 +44,7 @@ class DefaultOpExprGradFunction : public OpExprGradFunction Init(const OpExpr& op) override; Maybe Capture(DefaultOpExprInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const DefaultOpExprInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -285,7 +285,7 @@ Maybe DefaultOpExprGradFunction::UpdateRequiresBackward(DefaultOpExprInter Maybe DefaultOpExprGradFunction::Capture(DefaultOpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_OR_RETURN(attrs.empty()) << "The default op expr gradient func does not support dynamic attributes."; JUST(UpdateRequiresBackward(ctx, inputs)); diff --git a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp index 7a62a933570..00b85818112 100644 --- a/oneflow/core/autograd/gradient_funcs/layer_norm.cpp +++ b/oneflow/core/autograd/gradient_funcs/layer_norm.cpp @@ -38,7 +38,7 @@ class LayerNorm : public OpExprGradFunction { Maybe Init(const OpExpr& op) override; Maybe Capture(LayerNormInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const LayerNormInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -68,7 +68,7 @@ Maybe LayerNorm::Init(const OpExpr& op) { } Maybe LayerNorm::Capture(LayerNormInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), center_ + scale_ + 1); CHECK_EQ_OR_RETURN(inputs.size(), scale_ + 3); ctx->has_beta_diff = center_ && inputs.at(1)->requires_grad(); @@ -122,7 +122,7 @@ Maybe LayerNorm::Apply(const LayerNormInterpState* ctx, const TensorTuple& const auto& x = saved_tensors.at(offset); const auto& mean = saved_tensors.at(offset + 1); const auto& inv_variance = saved_tensors.at(offset + 2); - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr("begin_norm_axis", begin_norm_axis)); in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*x_grad_op_, {x, mean, inv_variance, dy}, attrs)); diff --git a/oneflow/core/autograd/gradient_funcs/normalization.cpp b/oneflow/core/autograd/gradient_funcs/normalization.cpp index 45773ca02c2..46bec025154 100644 --- a/oneflow/core/autograd/gradient_funcs/normalization.cpp +++ b/oneflow/core/autograd/gradient_funcs/normalization.cpp @@ -67,7 +67,7 @@ class NormalizationGrad : public OpExprGradFunction Capture(NormalizationGradInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override { + const TensorTuple& outputs, const AttrMap& attrs) const override { ctx->is_training = JUST(op_trait_->GetAttr("training", attrs)); ctx->SaveTensorForBackward(inputs.at(0)); // x ctx->SaveTensorForBackward(inputs.at(3)); // gamma @@ -118,7 +118,7 @@ class NormalizationGrad : public OpExprGradFunctionshape()->At(axis_)); } } - AttrValueMap shape_attr; + MutableAttrMap shape_attr; shape_attr.SetAttr("shape", Shape(dim_vec)); const auto& reshaped_gamma = JUST(OpInterpUtil::Dispatch(*reshape_gamma_op_, {gamma}, shape_attr)); diff --git a/oneflow/core/autograd/gradient_funcs/reshape.cpp b/oneflow/core/autograd/gradient_funcs/reshape.cpp index 13b1e7b5c6d..dcc99995288 100644 --- a/oneflow/core/autograd/gradient_funcs/reshape.cpp +++ b/oneflow/core/autograd/gradient_funcs/reshape.cpp @@ -34,7 +34,7 @@ class ReshapeOpExprGrad : public OpExprGradFunction { } Maybe Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrValueMap& attrs) const override { + const AttrMap& attrs) const override { ctx->SaveTensorForBackward(inputs.at(0)); return Maybe::Ok(); } diff --git a/oneflow/core/autograd/gradient_funcs/split_like.cpp b/oneflow/core/autograd/gradient_funcs/split_like.cpp index 7e1462d0866..c128a7dd120 100644 --- a/oneflow/core/autograd/gradient_funcs/split_like.cpp +++ b/oneflow/core/autograd/gradient_funcs/split_like.cpp @@ -32,7 +32,7 @@ class SplitLike : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(SplitLikeInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const SplitLikeInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -61,7 +61,7 @@ Maybe SplitLike::Init(const OpExpr& op) { } Maybe SplitLike::Capture(SplitLikeInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), outputs.size() + 1); ctx->requires_grad = inputs.at(0)->requires_grad(); if (!ctx->requires_grad) { return Maybe::Ok(); } @@ -92,7 +92,7 @@ Maybe SplitLike::Apply(const SplitLikeInterpState* ctx, const TensorTuple& inputs.push_back(zero_grad); } } - AttrValueMap concat_attrs; + MutableAttrMap concat_attrs; concat_attrs.SetAttr("axis", axis_); concat_attrs.SetAttr("max_dim_size", ctx->max_dim_size); in_grads->at(0) = JUST(OpInterpUtil::Dispatch(*concat_op_, inputs, concat_attrs)); diff --git a/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp b/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp index 774fb0ff675..1becec9a05f 100644 --- a/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp +++ b/oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp @@ -34,7 +34,7 @@ class TensorScalarAddOrSub : public OpExprGradFunction Maybe Init(const OpExpr& op) override; Maybe Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; protected: std::shared_ptr identity_op_; @@ -56,8 +56,7 @@ Maybe TensorScalarAddOrSub::Init(const OpExpr& op) { } Maybe TensorScalarAddOrSub::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, - const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->scalar_requires_grad = inputs.at(1)->requires_grad(); return Maybe::Ok(); @@ -75,7 +74,7 @@ class TensorScalarAdd : public TensorScalarAddOrSub { int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr>("axis", axes_vec)); in_grads->at(1) = JUST(OpInterpUtil::Dispatch(*reduce_sum_op_, {out_grads.at(0)}, attrs)); @@ -97,7 +96,7 @@ class TensorScalarSub : public TensorScalarAddOrSub { int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr>("axis", axes_vec)); const auto& reduce_sum = JUST(OpInterpUtil::Dispatch(*reduce_sum_op_, {out_grads.at(0)}, attrs)); @@ -114,7 +113,7 @@ class TensorScalarMul : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -136,7 +135,7 @@ Maybe TensorScalarMul::Init(const OpExpr& op) { } Maybe TensorScalarMul::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->scalar_requires_grad = inputs.at(1)->requires_grad(); if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); } @@ -158,7 +157,7 @@ Maybe TensorScalarMul::Apply(const TensorScalarInterpState* ctx, const Ten int32_t num_axes = out_grads.at(0)->shape()->NumAxes(); std::vector axes_vec(num_axes); std::iota(axes_vec.begin(), axes_vec.end(), 0); - AttrValueMap attrs; + MutableAttrMap attrs; JUST(attrs.SetAttr>("axis", axes_vec)); in_grads->at(1) = JUST(OpInterpUtil::Dispatch(*reduce_sum_op_, {y}, attrs)); } @@ -171,7 +170,7 @@ class TensorScalarDiv : public OpExprGradFunction { public: Maybe Init(const OpExpr& op) override; Maybe Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override; + const TensorTuple& outputs, const AttrMap& attrs) const override; Maybe Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const override; @@ -190,7 +189,7 @@ Maybe TensorScalarDiv::Init(const OpExpr& op) { } Maybe TensorScalarDiv::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const { + const TensorTuple& outputs, const AttrMap& attrs) const { ctx->x_requires_grad = inputs.at(0)->requires_grad(); ctx->scalar_requires_grad = inputs.at(1)->requires_grad(); if (ctx->x_requires_grad || ctx->scalar_requires_grad) { diff --git a/oneflow/core/common/shape.cpp b/oneflow/core/common/shape.cpp index d301ec33478..949935d361a 100644 --- a/oneflow/core/common/shape.cpp +++ b/oneflow/core/common/shape.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/common/shape.h" +#include "oneflow/core/common/shape.cfg.h" #include "oneflow/core/common/shape_view.h" #include "oneflow/core/common/protobuf.h" @@ -58,6 +59,11 @@ Shape::Shape(const ShapeProto& shape_proto) { UpdateElemCnt(); } +Shape::Shape(const cfg::ShapeProto& shape_proto) { + dim_vec_.assign(shape_proto.dim().begin(), shape_proto.dim().end()); + UpdateElemCnt(); +} + Shape& Shape::operator=(const Shape& shape) { dim_vec_ = shape.dim_vec_; UpdateElemCnt(); diff --git a/oneflow/core/common/shape.h b/oneflow/core/common/shape.h index d6210e653c3..513ec410f28 100644 --- a/oneflow/core/common/shape.h +++ b/oneflow/core/common/shape.h @@ -24,6 +24,10 @@ namespace oneflow { class ShapeView; +namespace cfg { +class ShapeProto; +} + class Shape final { public: // OF_DISALLOW_COPY_AND_MOVE(Shape); @@ -31,6 +35,7 @@ class Shape final { explicit Shape(const DimVector& dim_vec); explicit Shape(DimVector&& dim_vec); explicit Shape(const ShapeProto& shape_proto); + explicit Shape(const cfg::ShapeProto& shape_proto); Shape(const std::initializer_list& dim_vec); ~Shape() = default; Shape& operator=(const Shape& shape); diff --git a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h b/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h index 4751452552a..6847422d9f4 100644 --- a/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h +++ b/oneflow/core/eager/local_call_opkernel_phy_instr_operand.h @@ -17,7 +17,7 @@ limitations under the License. #define ONEFLOW_CORE_EAGER_LOCAL_CALL_OPKERNEL_PHY_INSTR_OPERAND_H_ #include "oneflow/core/eager/eager_blob_object.h" -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/vm/instruction_operand.msg.h" namespace oneflow { @@ -47,14 +47,13 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { LocalCallOpKernelPhyInstrOperand(const std::shared_ptr& opkernel, const one::EagerBlobObjectList inputs, - const one::EagerBlobObjectList outputs, - const AttrValueMap& attrs) + const one::EagerBlobObjectList outputs, const AttrMap& attrs) : opkernel_(opkernel), inputs_(inputs), outputs_(outputs), attrs_(attrs) {} const one::StatefulOpKernel& opkernel() const { return *opkernel_; } const one::EagerBlobObjectList& inputs() const { return inputs_; } const one::EagerBlobObjectList& outputs() const { return outputs_; } - const AttrValueMap& attrs() const { return attrs_; } + const AttrMap& attrs() const { return attrs_; } one::StatefulOpKernel* mut_opkernel() { return opkernel_.get(); } @@ -84,7 +83,7 @@ class LocalCallOpKernelPhyInstrOperand final : public vm::PhyInstrOperand { std::shared_ptr opkernel_; one::EagerBlobObjectList inputs_; one::EagerBlobObjectList outputs_; - const AttrValueMap attrs_; + const AttrMap attrs_; const user_op::OpKernel* user_opkernel_; }; diff --git a/oneflow/core/framework/attr_map.cpp b/oneflow/core/framework/attr_map.cpp new file mode 100644 index 00000000000..1df20eb2c63 --- /dev/null +++ b/oneflow/core/framework/attr_map.cpp @@ -0,0 +1,112 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/attr_value.h" +#include "oneflow/core/framework/attr_value_accessor.h" + +namespace oneflow { + +AttrMap::AttrMap(std::initializer_list init) : attrs_(new AttrName2AttrVal) { + for (const auto& pair : init) { attrs_->emplace(pair.first, pair.second); } +} + +AttrMap::AttrMap(const MutableAttrMap& other) : attrs_(new AttrName2AttrVal) { + for (const auto& pair : other) { attrs_->emplace(pair.first, pair.second); } +} + +AttrMap::AttrMap(const MutableCfgAttrMap& other) : attrs_(new AttrName2AttrVal) { + for (const auto& pair : other) { + const auto& attr_value = CHECK_JUST(user_op::AttrValueUtil::ToCppAttrValue(*pair.second)); + attrs_->emplace(pair.first, attr_value); + } +} + +template +Maybe AttrMap::GetAttr(const std::string& attr_name) const { + const auto& it = this->find(attr_name); + CHECK_OR_RETURN(it != this->end()); + const auto* ptr = dynamic_cast*>(it->second.get()); + CHECK_NOTNULL_OR_RETURN(ptr); + return ptr->val(); +} + +template +Maybe ComposedAttrMap::GetAttr(const std::string& attr_name) const { + { + const auto& it = prior_.find(attr_name); + if (it != prior_.end()) { + const auto* ptr = dynamic_cast*>(it->second.get()); + CHECK_NOTNULL_OR_RETURN(ptr); + return ptr->val(); + } + } + { + const auto& it = base_.find(attr_name); + if (it != base_.end()) { + const auto* ptr = dynamic_cast*>(it->second.get()); + CHECK_NOTNULL_OR_RETURN(ptr); + return ptr->val(); + } + } + return Error::ValueError(std::string("no attribute found. attribute name: ") + attr_name); +} + +#define DEFINE_ATTR_VALUE_MAP_GET_ATTR(field, T, attr_type) \ + template Maybe AttrMap::GetAttr(const std::string& attr_name) const; \ + template Maybe ComposedAttrMap::GetAttr(const std::string& attr_name) const; + +OF_PP_FOR_EACH_TUPLE(DEFINE_ATTR_VALUE_MAP_GET_ATTR, ATTR_SEQ); +#undef DEFINE_ATTR_VALUE_MAP_GET_ATTR + +template<> +Maybe MutableAttrMap::SetAttr(const std::string& attr_name, + const std::shared_ptr& attr_val) { + (*this)[attr_name] = attr_val; + return Maybe::Ok(); +} + +template +Maybe MutableAttrMap::SetAttr(const std::string& attr_name, const T& attr_val) { + (*this)[attr_name] = std::make_shared>(attr_val); + return Maybe::Ok(); +} + +template<> +Maybe MutableCfgAttrMap::SetAttr(const std::string& attr_name, + const std::shared_ptr& attr_val) { + (*this)[attr_name] = attr_val; + return Maybe::Ok(); +} + +template +Maybe MutableCfgAttrMap::SetAttr(const std::string& attr_name, const T& attr_val) { + AttrValue proto_attr_val; + user_op::AttrValueAccessor::Attr(attr_val, &proto_attr_val); + (*this)[attr_name] = std::make_shared(proto_attr_val); + return Maybe::Ok(); +} + +#define DEFINE_ATTR_VALUE_MAP_SET_ATTR(field, T, attr_type) \ + template Maybe MutableAttrMap::SetAttr(const std::string& attr_name, \ + const T& attr_val); \ + template Maybe MutableCfgAttrMap::SetAttr(const std::string& attr_name, \ + const T& attr_val); + +OF_PP_FOR_EACH_TUPLE(DEFINE_ATTR_VALUE_MAP_SET_ATTR, ATTR_SEQ); +#undef DEFINE_ATTR_VALUE_MAP_SET_ATTR + +} // namespace oneflow diff --git a/oneflow/core/framework/attr_map.h b/oneflow/core/framework/attr_map.h new file mode 100644 index 00000000000..cbf5a026b31 --- /dev/null +++ b/oneflow/core/framework/attr_map.h @@ -0,0 +1,100 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_ +#define ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/framework/user_op_attr.cfg.h" +#include "oneflow/core/framework/user_op_conf.h" + +namespace oneflow { + +class MutableAttrMap; +class MutableCfgAttrMap; + +using AttrName2AttrVal = HashMap>; + +class AttrMap final { + public: + AttrMap() : attrs_(new AttrName2AttrVal) {} + explicit AttrMap(const std::shared_ptr& attrs) : attrs_(attrs) {} + + using value_type = typename AttrName2AttrVal::value_type; + AttrMap(std::initializer_list init); + + AttrMap(const MutableAttrMap& other); // without coping AttrVal + AttrMap(const MutableCfgAttrMap& other); + + AttrMap(const AttrMap&) = default; + AttrMap(AttrMap&&) = default; + ~AttrMap() = default; + + AttrMap& operator=(const AttrMap& other) { + attrs_ = other.attrs_; + return *this; + } + + template + Maybe GetAttr(const std::string& attr_name) const; + + size_t size() const { return attrs_->size(); } + bool empty() const { return attrs_->empty(); } + + using const_iterator = typename AttrName2AttrVal::const_iterator; + const_iterator begin() const { return attrs_->begin(); } + const_iterator end() const { return attrs_->end(); } + + const_iterator find(const std::string& attr_name) const { return attrs_->find(attr_name); } + + private: + std::shared_ptr attrs_; +}; + +class ComposedAttrMap final { + public: + ComposedAttrMap(const AttrMap& base) : base_(base) {} + ComposedAttrMap(const AttrMap& prior, const AttrMap& base) : prior_(prior), base_(base) {} + + template + Maybe GetAttr(const std::string& attr_name) const; + + void ResetPrior(const AttrMap& prior) { prior_ = prior; } + void ResetBase(const AttrMap& base) { base_ = base; } + + private: + AttrMap prior_; + AttrMap base_; +}; + +class MutableAttrMap : public HashMap> { + public: + using HashMap>::HashMap; + + template + Maybe SetAttr(const std::string& attr_name, const T& attr_val); +}; + +class MutableCfgAttrMap : public HashMap> { + public: + using HashMap>::HashMap; + + template + Maybe SetAttr(const std::string& attr_name, const T& attr_val); +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_FRAMEWORK_ATTR_MAP_H_ diff --git a/oneflow/core/framework/attr_value_accessor.cpp b/oneflow/core/framework/attr_value_accessor.cpp index dcf7dba80be..e18cdf9e36e 100644 --- a/oneflow/core/framework/attr_value_accessor.cpp +++ b/oneflow/core/framework/attr_value_accessor.cpp @@ -17,6 +17,8 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/common/shape.h" #include "oneflow/core/common/protobuf.h" +#include "oneflow/core/framework/user_op_conf.h" +#include "oneflow/core/framework/user_op_attr.cfg.h" namespace oneflow { @@ -30,6 +32,11 @@ namespace user_op { return val.field(); \ } \ template<> \ + cpp_type AttrValueAccessor::Attr(const cfg::AttrValue& val) { \ + CHECK(val.has_##field()); \ + return static_cast(val.field()); \ + } \ + template<> \ void AttrValueAccessor::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \ attr_val->set_##field(cpp_val); \ } @@ -48,6 +55,12 @@ template<> Shape AttrValueAccessor::Attr(const AttrValue& val) { return Shape(val.at_shape()); } + +template<> +Shape AttrValueAccessor::Attr(const cfg::AttrValue& val) { + return Shape(val.at_shape()); +} + template<> void AttrValueAccessor::Attr(const Shape& cpp_val, AttrValue* attr_val) { cpp_val.ToProto(attr_val->mutable_at_shape()); @@ -60,6 +73,11 @@ void AttrValueAccessor::Attr(const Shape& cpp_val, AttrValue* attr_val) { return PbRf2StdVec(val.field().val()); \ } \ template<> \ + cpp_type AttrValueAccessor::Attr(const cfg::AttrValue& val) { \ + const auto& rp_val = val.field().val(); \ + return cpp_type(rp_val.begin(), rp_val.end()); \ + } \ + template<> \ void AttrValueAccessor::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \ *(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf(cpp_val); \ } @@ -80,6 +98,15 @@ OF_PP_FOR_EACH_TUPLE(LIST_BASIC_ATTR_SEQ_ENTRY, LIST_BASIC_ATTR_SEQ) return ret; \ } \ template<> \ + cpp_type AttrValueAccessor::Attr(const cfg::AttrValue& val) { \ + std::vector ret; \ + ret.reserve(val.field().val_size()); \ + for (const auto& value : val.field().val()) { \ + ret.emplace_back(static_cast(value)); \ + } \ + return ret; \ + } \ + template<> \ void AttrValueAccessor::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \ using proto_type = std::remove_reference_tfield().val())>::value_type; \ std::vector vec; \ @@ -120,6 +147,47 @@ void AttrValueAccessor>::Attr(const std::vectormutable_at_list_string()->mutable_val()) = StdVec2PbRpf(cpp_val); } +template +Maybe MakeCppAttrValueFromProtoOrCfgAttrValue(const ProtoT& cfg_attr_value) { + switch (static_cast(cfg_attr_value.value_case())) { +#define MAKE_ENTRY(field, T, attr_type) \ + case static_cast(attr_type): \ + return std::static_pointer_cast( \ + std::make_shared>(AttrValueAccessor::Attr(cfg_attr_value))); + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ); +#undef MAKE_ENTRY + default: OF_UNIMPLEMENTED(); + } +} + +/*static*/ Maybe AttrValueUtil::ToCppAttrValue(const AttrValue& proto_attr_value) { + return MakeCppAttrValueFromProtoOrCfgAttrValue(proto_attr_value); +} + +/*static*/ Maybe AttrValueUtil::ToCppAttrValue(const cfg::AttrValue& cfg_attr_value) { + AttrValue proto_attr_value; + cfg_attr_value.ToProto(&proto_attr_value); + return MakeCppAttrValueFromProtoOrCfgAttrValue(proto_attr_value); +} + +/*static*/ Maybe AttrValueUtil::ToProtoAttrValue(const AttrVal& cpp_attr_value, + AttrValue* attr_value) { + if (false) { +// clang-format off +#define MAKE_ENTRY(field, cpp_type, attr_type) \ + } \ + else if (dynamic_cast*>(&cpp_attr_value) != nullptr) { \ + const auto* ptr = dynamic_cast*>(&cpp_attr_value); \ + AttrValueAccessor::Attr(ptr->val(), attr_value); + OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ); +#undef MAKE_ENTRY + // clang-format on + } else { + OF_UNIMPLEMENTED(); + } + return Maybe::Ok(); +} + } // namespace user_op } // namespace oneflow diff --git a/oneflow/core/framework/attr_value_accessor.h b/oneflow/core/framework/attr_value_accessor.h index c1a20ac7f36..25059c5158c 100644 --- a/oneflow/core/framework/attr_value_accessor.h +++ b/oneflow/core/framework/attr_value_accessor.h @@ -17,17 +17,31 @@ limitations under the License. #define ONEFLOW_CORE_FRAMEWORK_ATTR_VAL_ACCESSOR_H_ #include "oneflow/core/framework/attr_value.h" +#include "oneflow/core/common/maybe.h" namespace oneflow { +namespace cfg { +class AttrValue; +} + namespace user_op { template struct AttrValueAccessor final { static T Attr(const AttrValue&); + static T Attr(const cfg::AttrValue&); static void Attr(const T&, AttrValue*); }; +class AttrVal; + +struct AttrValueUtil final { + static Maybe ToCppAttrValue(const AttrValue& proto_attr_value); + static Maybe ToCppAttrValue(const cfg::AttrValue& cfg_attr_value); + static Maybe ToProtoAttrValue(const AttrVal& cpp_attr_value, AttrValue* attr_value); +}; + } // namespace user_op } // namespace oneflow diff --git a/oneflow/core/framework/attr_value_map.cpp b/oneflow/core/framework/attr_value_map.cpp deleted file mode 100644 index 6686d2e6145..00000000000 --- a/oneflow/core/framework/attr_value_map.cpp +++ /dev/null @@ -1,66 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include "oneflow/core/framework/attr_value_map.h" -#include "oneflow/core/framework/attr_value_accessor.h" - -namespace oneflow { - -template<> -Maybe AttrValueMap::GetAttr(const std::string& attr_name) const { - const auto& it = this->find(attr_name); - CHECK_OR_RETURN(it != this->end()); - return it->second; -} - -template<> -Maybe AttrValueMap::SetAttr(const std::string& attr_name, - const std::shared_ptr& attr_val) { - auto it = this->find(attr_name); - if (it == this->end()) { - this->emplace(attr_name, attr_val); - } else { - it->second = attr_val; - } - return Maybe::Ok(); -} - -#define DEFINE_ATTR_VALUE_MAP_ATTRIBUTE_GETTER(field, cpp_type, attr_type) \ - template<> \ - Maybe AttrValueMap::GetAttr(const std::string& attr_name) const { \ - const auto& it = this->find(attr_name); \ - CHECK_OR_RETURN(it != this->end()); \ - AttrValue attr_vale; \ - it->second->ToProto(&attr_vale); \ - return user_op::AttrValueAccessor::Attr(attr_vale); \ - } - -#define DEFINE_ATTR_VALUE_MAP_ATTRIBUTE_SETTER(field, cpp_type, attr_type) \ - template<> \ - Maybe AttrValueMap::SetAttr(const std::string& attr_name, const cpp_type& val) { \ - AttrValue attr_val; \ - user_op::AttrValueAccessor::Attr(val, &attr_val); \ - SetAttr(attr_name, std::make_shared(attr_val)); \ - return Maybe::Ok(); \ - } - -OF_PP_FOR_EACH_TUPLE(DEFINE_ATTR_VALUE_MAP_ATTRIBUTE_GETTER, ATTR_SEQ); -OF_PP_FOR_EACH_TUPLE(DEFINE_ATTR_VALUE_MAP_ATTRIBUTE_SETTER, ATTR_SEQ); - -#undef DEFINE_ATTR_VALUE_MAP_ATTRIBUTE_GETTER -#undef DEFINE_ATTR_VALUE_MAP_ATTRIBUTE_SETTER - -} // namespace oneflow diff --git a/oneflow/core/framework/attr_value_map.h b/oneflow/core/framework/attr_value_map.h deleted file mode 100644 index 51aa8d9aa6b..00000000000 --- a/oneflow/core/framework/attr_value_map.h +++ /dev/null @@ -1,37 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -#ifndef ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_MAP_H_ -#define ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_MAP_H_ - -#include "oneflow/core/common/util.h" -#include "oneflow/core/framework/user_op_attr.cfg.h" - -namespace oneflow { - -class AttrValueMap : public HashMap> { - public: - using HashMap>::HashMap; - - template - Maybe GetAttr(const std::string& attr_name) const; - - template - Maybe SetAttr(const std::string& attr_name, const T& attr_val); -}; - -} // namespace oneflow - -#endif // ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_MAP_H_ diff --git a/oneflow/core/framework/instructions_builder.cpp b/oneflow/core/framework/instructions_builder.cpp index 81ddfe79e6f..b2f938dc17c 100644 --- a/oneflow/core/framework/instructions_builder.cpp +++ b/oneflow/core/framework/instructions_builder.cpp @@ -646,7 +646,7 @@ Maybe InstructionsBuilder::BuildRecvInstruction( Maybe InstructionsBuilder::LocalCallOpKernel( const std::shared_ptr& opkernel, one::EagerBlobObjectList input_eager_blob_objects, - one::EagerBlobObjectList output_eager_blob_objects, const AttrValueMap& attrs, + one::EagerBlobObjectList output_eager_blob_objects, const AttrMap& attrs, const std::shared_ptr& parallel_desc_sym) { ObjectMsgPtr instruction = ObjectMsgPtr::New(parallel_desc_sym->device_tag() + ".LocalCallOpKernel"); diff --git a/oneflow/core/framework/instructions_builder.h b/oneflow/core/framework/instructions_builder.h index 61dd4114ddf..b36bcdf96cd 100644 --- a/oneflow/core/framework/instructions_builder.h +++ b/oneflow/core/framework/instructions_builder.h @@ -249,7 +249,7 @@ class InstructionsBuilder : public std::enable_shared_from_this LocalCallOpKernel(const std::shared_ptr& opkernel, one::EagerBlobObjectList input_eager_blob_objects, one::EagerBlobObjectList output_eager_blob_objects, - const AttrValueMap& attrs, + const AttrMap& attrs, const std::shared_ptr& parallel_desc_sym); private: diff --git a/oneflow/core/framework/op_expr.cpp b/oneflow/core/framework/op_expr.cpp index 1c5ad30e9dd..d0a43f9cc90 100644 --- a/oneflow/core/framework/op_expr.cpp +++ b/oneflow/core/framework/op_expr.cpp @@ -13,8 +13,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "oneflow/core/common/auto_registration_factory.h" #include "oneflow/core/framework/op_expr.h" + +#include "oneflow/core/common/auto_registration_factory.h" +#include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/op_expr_grad_function.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/user/kernels/stateful_local_opkernel.h" @@ -65,13 +67,13 @@ DEFINE_OPEXPR_TYPE_NAME(DistributeAddOpConf, "distribute_add"); template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_user_conf()) = op_proto_; auto* user_op_conf = op_conf->mutable_user_conf(); for (const auto& it : attrs) { AttrValue attr_val; - it.second->ToProto(&attr_val); + user_op::AttrValueUtil::ToProtoAttrValue(*it.second, &attr_val); (*(user_op_conf->mutable_attr()))[it.first] = attr_val; } return Maybe::Ok(); @@ -122,7 +124,7 @@ Maybe BuiltinOpExprImpl::GetOrCreateOpGradClosure template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_variable_conf()) = op_proto_; @@ -141,7 +143,7 @@ Maybe BuiltinOpExprImpl::GetOrCreateOpGradClo template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_cast_to_mirrored_conf()) = op_proto_; @@ -154,8 +156,8 @@ Maybe BuiltinOpExprImpl::GetOrCreateOpG } template<> -Maybe BuiltinOpExprImpl::BuildOpConf( - OperatorConf* op_conf, const AttrValueMap& attrs) const { +Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_cast_from_mirrored_conf()) = op_proto_; @@ -170,7 +172,7 @@ Maybe BuiltinOpExprImpl::GetOrCreateO template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_split_conf()) = op_proto_; @@ -185,7 +187,7 @@ Maybe BuiltinOpExprImpl::GetOrCreateOp template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_clone_conf()) = op_proto_; @@ -199,8 +201,8 @@ Maybe BuiltinOpExprImpl::GetOrCreateOp } template<> -Maybe BuiltinOpExprImpl::BuildOpConf( - OperatorConf* op_conf, const AttrValueMap& attrs) const { +Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_concat_conf()) = op_proto_; @@ -215,7 +217,7 @@ Maybe BuiltinOpExprImpl::GetOrCreateO template<> Maybe BuiltinOpExprImpl::BuildOpConf(OperatorConf* op_conf, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(attrs.size(), 0); *(op_conf->mutable_name()) = op_name_; *(op_conf->mutable_distribute_add_conf()) = op_proto_; diff --git a/oneflow/core/framework/op_expr.h b/oneflow/core/framework/op_expr.h index 1a9f92e10d5..1a53766f291 100644 --- a/oneflow/core/framework/op_expr.h +++ b/oneflow/core/framework/op_expr.h @@ -18,7 +18,7 @@ limitations under the License. #include "oneflow/core/common/util.h" #include "oneflow/core/operator/op_conf.pb.h" -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/device.h" #include "oneflow/core/framework/tensor_tuple.h" #include "oneflow/core/framework/user_op_conf.pb.h" @@ -67,7 +67,7 @@ class BuiltinOpExpr : public OpExpr { return indexed_output_pairs_; } - virtual Maybe BuildOpConf(OperatorConf* op_conf, const AttrValueMap& attrs) const = 0; + virtual Maybe BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const = 0; protected: std::string op_name_; @@ -98,7 +98,7 @@ class BuiltinOpExprImpl : public BuiltinOpExpr { Maybe GetOrCreateOpGradClosure() const override; - Maybe BuildOpConf(OperatorConf* op_conf, const AttrValueMap& attrs) const override; + Maybe BuildOpConf(OperatorConf* op_conf, const AttrMap& attrs) const override; protected: ProtoType op_proto_; diff --git a/oneflow/core/framework/op_expr_grad_function.h b/oneflow/core/framework/op_expr_grad_function.h index a6d25c37a92..ea637d931d3 100644 --- a/oneflow/core/framework/op_expr_grad_function.h +++ b/oneflow/core/framework/op_expr_grad_function.h @@ -37,7 +37,7 @@ class OpExprGradFunctionIf { // Capture forward inputs and outputs for backward. virtual Maybe CaptureIf(OpExprInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const = 0; + const TensorTuple& outputs, const AttrMap& attrs) const = 0; virtual Maybe ApplyIf(const OpExprInterpState* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const = 0; @@ -51,7 +51,7 @@ class OpExprGradFunction : public OpExprGradFunctionIf { } Maybe CaptureIf(OpExprInterpState* ctx, const TensorTuple& inputs, - const TensorTuple& outputs, const AttrValueMap& attrs) const override { + const TensorTuple& outputs, const AttrMap& attrs) const override { StateT* state = dynamic_cast(ctx); CHECK_NOTNULL_OR_RETURN(state); return Capture(state, inputs, outputs, attrs); @@ -65,7 +65,7 @@ class OpExprGradFunction : public OpExprGradFunctionIf { } virtual Maybe Capture(StateT* ctx, const TensorTuple& inputs, const TensorTuple& outputs, - const AttrValueMap& attrs) const = 0; + const AttrMap& attrs) const = 0; virtual Maybe Apply(const StateT* ctx, const TensorTuple& out_grads, TensorTuple* in_grads) const = 0; @@ -88,7 +88,7 @@ class OpExprGradClosure { virtual ~OpExprGradClosure() = default; Maybe Capture(const TensorTuple& inputs, const TensorTuple& outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return impl_->CaptureIf(state_.get(), inputs, outputs, attrs); } diff --git a/oneflow/core/framework/op_interpreter.h b/oneflow/core/framework/op_interpreter.h index 3413955f74d..e8b8febeba3 100644 --- a/oneflow/core/framework/op_interpreter.h +++ b/oneflow/core/framework/op_interpreter.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_H_ -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/op_expr.h" #include "oneflow/core/framework/tensor.h" #include "oneflow/core/framework/tensor_tuple.h" @@ -45,10 +45,10 @@ class OpExprInterpreter { virtual ~OpExprInterpreter() = default; virtual Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const = 0; + const AttrMap& attrs) const = 0; Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const { - return Apply(op, inputs, outputs, AttrValueMap{}); + return Apply(op, inputs, outputs, AttrMap{}); } }; @@ -64,13 +64,13 @@ class OpExprInterpreter { #define DECLARE_NORMAL_APPLY_FUNC(op_type) \ virtual Maybe ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \ - TensorTuple* outputs, const AttrValueMap& attrs) const + TensorTuple* outputs, const AttrMap& attrs) const #define DECLARE_PURE_VIRTUAL_APPLY_FUNC(op_type) DECLARE_NORMAL_APPLY_FUNC(op_type) = 0; #define DECLARE_OVERRIDE_APPLY_FUNC(op_type) \ Maybe ApplyImpl(const op_type##Expr& op_expr, const TensorTuple& inputs, \ - TensorTuple* outputs, const AttrValueMap& attrs) const override; + TensorTuple* outputs, const AttrMap& attrs) const override; class LazyInterpreter : public OpExprInterpreter { public: @@ -78,7 +78,7 @@ class LazyInterpreter : public OpExprInterpreter { virtual ~LazyInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const override; + const AttrMap& attrs) const override; private: DECLARE_NORMAL_APPLY_FUNC(BuiltinOp); @@ -91,7 +91,7 @@ class EagerInterpreter : public OpExprInterpreter { virtual ~EagerInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const override; + const AttrMap& attrs) const override; private: FOR_EACH_BUILTIN_OPS(DECLARE_PURE_VIRTUAL_APPLY_FUNC); @@ -129,10 +129,10 @@ class AutogradInterpreter { virtual ~AutogradInterpreter() = default; Maybe Apply(const OpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const; + const AttrMap& attrs) const; Maybe Apply(const OpExpr& op, const TensorTuple& inputs, TensorTuple* outputs) const { - return Apply(op, inputs, outputs, AttrValueMap{}); + return Apply(op, inputs, outputs, AttrMap{}); } private: diff --git a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp index ba732d8d7a0..66ba611f06f 100644 --- a/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp @@ -30,7 +30,7 @@ namespace oneflow { namespace one { static Maybe NaiveInterpret(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) { + TensorTuple* outputs, const AttrMap& attrs) { using namespace std::placeholders; const auto& scope = JUST(GetCurrentScope()); const auto& op_attribute = JUST(OpInterpUtil::InferOpAttribute(op_expr, inputs, attrs)); @@ -54,13 +54,13 @@ static Maybe NaiveInterpret(const BuiltinOpExpr& op_expr, const TensorTupl Maybe EagerConsistentInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return NaiveInterpret(op_expr, inputs, outputs, attrs); } Maybe EagerConsistentInterpreter::ApplyImpl(const VariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 0); CHECK_EQ_OR_RETURN(outputs->size(), 1); return NaiveInterpret(op_expr, inputs, outputs, attrs); @@ -90,13 +90,13 @@ static Maybe BuildAndRunMirroredCastInstruction(const BuiltinOpExpr& op_ex Maybe EagerConsistentInterpreter::ApplyImpl(const CastToMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunMirroredCastInstruction(op_expr, inputs, outputs); } Maybe EagerConsistentInterpreter::ApplyImpl(const CastFromMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunMirroredCastInstruction(op_expr, inputs, outputs); } @@ -137,13 +137,13 @@ static Maybe BuildAndRunDistributeSplitOrCloneInstruction(const BuiltinOpE Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } @@ -178,13 +178,13 @@ static Maybe BuildAndRunDistributeConcatAndAddInstruction(const BuiltinOpE Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } Maybe EagerConsistentInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } diff --git a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp index 9ee13a704d3..3d212a60199 100644 --- a/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/eager_mirrored_op_interpreter.cpp @@ -46,7 +46,7 @@ Maybe NaiveInterpret(const UserOpExpr& user_op_expr, input_eager_blob_objects, const std::shared_ptr>>& output_eager_blob_objects, - const AttrValueMap& attrs, const std::shared_ptr device, + const AttrMap& attrs, const std::shared_ptr device, std::shared_ptr parallel_desc) { const auto kernel = JUST(user_op_expr.MutKernel4Device(*device)); const auto mem_case = kernel->mem_case(); @@ -79,13 +79,13 @@ Maybe GenerateAllocatedEagerBlobObject(DataType data_type, std::shared_ptr parallel_desc = JUST(Device::MakeParallelDescByDevice(*device)); - JUST(NaiveInterpret(*zeros_expr, input_eager_blob_objects, output_eager_blob_objects, - AttrValueMap{}, device, parallel_desc)); + JUST(NaiveInterpret(*zeros_expr, input_eager_blob_objects, output_eager_blob_objects, AttrMap{}, + device, parallel_desc)); return output_eager_blob_objects->at(0); } static Maybe NaiveInterpret(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) { + TensorTuple* outputs, const AttrMap& attrs) { std::shared_ptr device; if (inputs.empty()) { device = GetDefaultDevice(); @@ -114,13 +114,13 @@ static Maybe NaiveInterpret(const BuiltinOpExpr& op_expr, const TensorTupl Maybe EagerMirroredInterpreter::ApplyImpl(const UserOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return NaiveInterpret(op_expr, inputs, outputs, attrs); } Maybe EagerMirroredInterpreter::ApplyImpl(const VariableOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), 0); CHECK_EQ_OR_RETURN(outputs->size(), 1); return NaiveInterpret(op_expr, inputs, outputs, attrs); @@ -135,13 +135,13 @@ static Maybe BuildAndRunMirroredCastInstruction(const BuiltinOpExpr& op_ex Maybe EagerMirroredInterpreter::ApplyImpl(const CastToMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunMirroredCastInstruction(op_expr, inputs, outputs); } Maybe EagerMirroredInterpreter::ApplyImpl(const CastFromMirroredOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunMirroredCastInstruction(op_expr, inputs, outputs); } @@ -154,13 +154,13 @@ static Maybe BuildAndRunDistributeSplitOrCloneInstruction(const BuiltinOpE Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeSplitOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeCloneOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeSplitOrCloneInstruction(op_expr, inputs, outputs); } @@ -173,13 +173,13 @@ static Maybe BuildAndRunDistributeConcatAndAddInstruction(const BuiltinOpE Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeConcatOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } Maybe EagerMirroredInterpreter::ApplyImpl(const DistributeAddOpExpr& op_expr, const TensorTuple& inputs, TensorTuple* outputs, - const AttrValueMap& attrs) const { + const AttrMap& attrs) const { return BuildAndRunDistributeConcatAndAddInstruction(op_expr, inputs, outputs); } diff --git a/oneflow/core/framework/op_interpreter/op_interpreter.cpp b/oneflow/core/framework/op_interpreter/op_interpreter.cpp index 67fada00bb8..097241171df 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter.cpp @@ -34,7 +34,7 @@ namespace oneflow { namespace one { Maybe LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) const { + TensorTuple* outputs, const AttrMap& attrs) const { #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ return ApplyImpl(*op, inputs, outputs, attrs); \ @@ -49,7 +49,7 @@ Maybe LazyInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inp } Maybe LazyInterpreter::ApplyImpl(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) const { + TensorTuple* outputs, const AttrMap& attrs) const { CHECK_EQ_OR_RETURN(inputs.size(), op_expr.input_size()); const auto& scope = JUST(GetCurrentScope()); auto op_conf = JUST(OpInterpUtil::GenBuiltinOpConf(op_expr, attrs)); @@ -93,14 +93,14 @@ Maybe LazyInterpreter::ApplyImpl(const BuiltinOpExpr& op_expr, const Tenso } Maybe LazyInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) const { + TensorTuple* outputs, const AttrMap& attrs) const { // TODO(hjchen2) UNIMPLEMENTED(); return Maybe::Ok(); } Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) const { + TensorTuple* outputs, const AttrMap& attrs) const { #define APPLY_IF(op_type) \ if (const auto* op = dynamic_cast(&op_expr)) { \ return ApplyImpl(*op, inputs, outputs, attrs); \ @@ -122,7 +122,7 @@ Maybe EagerInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& in } Maybe EagerInterpreter::ApplyImpl(const FunctionOpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) const { + TensorTuple* outputs, const AttrMap& attrs) const { // TODO(hjchen2) UNIMPLEMENTED(); return Maybe::Ok(); @@ -144,7 +144,7 @@ Maybe DetermineRequiresGrad(TensorTuple* outputs, const bool& requires_gra } // namespace Maybe AutogradInterpreter::Apply(const OpExpr& op_expr, const TensorTuple& inputs, - TensorTuple* outputs, const AttrValueMap& attrs) const { + TensorTuple* outputs, const AttrMap& attrs) const { bool requires_grad = false; { autograd::AutoGradMode mode(false); diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp index 54abf6ef767..f36fb8a116a 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp +++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.cpp @@ -66,7 +66,7 @@ std::shared_ptr BuildLazyInterpreter() { template<> /*static*/ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, - const AttrValueMap& attrs) { + const AttrMap& attrs) { auto outputs = std::make_shared(op_expr.output_size()); JUST(GetInterpreter())->Apply(op_expr, inputs, outputs.get(), attrs); return outputs; @@ -75,7 +75,7 @@ template<> template<> /*static*/ Maybe OpInterpUtil::Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, - const AttrValueMap& attrs) { + const AttrMap& attrs) { return JUST(Dispatch(op_expr, inputs, attrs))->at(0); } @@ -94,7 +94,7 @@ template<> /*static*/ Maybe OpInterpUtil::InferOpAttribute(const BuiltinOpExpr& op_expr, const TensorTuple& inputs, - const AttrValueMap& attrs) { + const AttrMap& attrs) { const auto& scope = JUST(GetCurrentScope()); auto op_conf = JUST(GenBuiltinOpConf(op_expr, attrs)); int64_t symbol_id = JUST(scope->symbol_id()); @@ -133,7 +133,7 @@ using Bn2BlobObjectMap = HashMap OpInterpUtil::GenBuiltinOpConf(const BuiltinOpExpr& op_expr, - const AttrValueMap& attrs) { + const AttrMap& attrs) { auto op_conf = std::make_shared(); op_expr.BuildOpConf(op_conf.get(), attrs); return op_conf; diff --git a/oneflow/core/framework/op_interpreter/op_interpreter_util.h b/oneflow/core/framework/op_interpreter/op_interpreter_util.h index 6a08a26765e..3cf4d413670 100644 --- a/oneflow/core/framework/op_interpreter/op_interpreter_util.h +++ b/oneflow/core/framework/op_interpreter/op_interpreter_util.h @@ -16,7 +16,7 @@ limitations under the License. #ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_ #define ONEFLOW_CORE_FRAMEWORK_OP_INTERPRETER_UTIL_H_ -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" #include "oneflow/core/framework/instructions_builder.h" #include "oneflow/core/framework/op_arg_util.h" #include "oneflow/core/framework/op_expr.h" @@ -34,23 +34,20 @@ class OpInterpUtil { static Maybe GetInterpreter(); template - static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, - const AttrValueMap& attrs); + static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs, const AttrMap& attrs); template static Maybe Dispatch(const OpExpr& op_expr, const TensorTuple& inputs) { - return Dispatch(op_expr, inputs, AttrValueMap{}); + return Dispatch(op_expr, inputs, AttrMap{}); } - static Maybe GenBuiltinOpConf(const BuiltinOpExpr& op_expr, - const AttrValueMap& attrs); + static Maybe GenBuiltinOpConf(const BuiltinOpExpr& op_expr, const AttrMap& attrs); static Maybe AddOpAndInferOpAttribute(const OperatorConf& op_conf, const bool is_mirrored_strategy_enabled); static Maybe InferOpAttribute(const BuiltinOpExpr& op_expr, - const TensorTuple& inputs, - const AttrValueMap& attrs); + const TensorTuple& inputs, const AttrMap& attrs); static Maybe>> MakeBn2BlobObjectMap(const std::vector& indexed_ibns, const TensorTuple& inputs); diff --git a/oneflow/core/framework/user_op_conf_trait.h b/oneflow/core/framework/user_op_conf_trait.h index 35a3b3a121b..130f03c4160 100644 --- a/oneflow/core/framework/user_op_conf_trait.h +++ b/oneflow/core/framework/user_op_conf_trait.h @@ -18,7 +18,7 @@ limitations under the License. #define ONEFLOW_CORE_FRAMEWORK_USER_OP_CONF_TRAIT_H_ #include "oneflow/core/framework/user_op_conf.h" -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" namespace oneflow { namespace user_op { @@ -46,12 +46,15 @@ class UserOpConfTrait { Maybe GetAttr(const std::string& attr_name) const { const auto& it = attrs_.find(attr_name); CHECK_OR_RETURN(it != attrs_.end()) << "The op has no attribute named " << attr_name; - return std::dynamic_pointer_cast>(it->second)->val(); + return dynamic_cast*>(it->second.get())->val(); } template - Maybe GetAttr(const std::string& attr_name, const AttrValueMap& priority_attrs) const { - if (priority_attrs.count(attr_name)) { return priority_attrs.GetAttr(attr_name); } + Maybe GetAttr(const std::string& attr_name, const AttrMap& priority_attrs) const { + const auto& it = priority_attrs.find(attr_name); + if (it != priority_attrs.end()) { + return dynamic_cast*>(it->second.get())->val(); + } return GetAttr(attr_name); } @@ -60,7 +63,7 @@ class UserOpConfTrait { std::string op_type_name_; HashMap> inputs_; HashMap> outputs_; - HashMap> attrs_; + HashMap> attrs_; }; } // namespace user_op diff --git a/oneflow/python/framework/op_expr_util.py b/oneflow/python/framework/op_expr_util.py index 9402c797a04..a8cad61c20e 100644 --- a/oneflow/python/framework/op_expr_util.py +++ b/oneflow/python/framework/op_expr_util.py @@ -27,7 +27,7 @@ def user_op_expr_call(self, *args, **kwargs): arg.determine() args[i] = arg._local_or_consistent_tensor - attrs = oneflow._oneflow_internal.AttrValueMap() + attrs = oneflow._oneflow_internal.MutableCfgAttrMap() for attr_name, attr_value in kwargs.items(): assert isinstance(attr_name, str) attrs[attr_name] = convert_to_user_attr_value( diff --git a/oneflow/user/kernels/stateful_local_opkernel.cpp b/oneflow/user/kernels/stateful_local_opkernel.cpp index 77d60ca808e..f83be2c27d0 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.cpp +++ b/oneflow/user/kernels/stateful_local_opkernel.cpp @@ -14,12 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/user/kernels/stateful_local_opkernel.h" + +#include "oneflow/core/framework/attr_value_accessor.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry_manager.h" #include "oneflow/core/kernel/kernel.h" #include "oneflow/core/eager/eager_blob_object.h" #include "oneflow/core/framework/attr_value_accessor.h" -#include "oneflow/core/framework/attr_value_map.h" +#include "oneflow/core/framework/attr_map.h" namespace oneflow { namespace one { @@ -433,14 +435,14 @@ LocalUserKernelComputeContext* StatefulOpKernel::UpdateComputeContext(EagerBlobO return compute_ctx_.get(); } -void StatefulOpKernel::ResetDynamicOpAttrs(const AttrValueMap& attrs) { +void StatefulOpKernel::ResetDynamicOpAttrs(const AttrMap& attrs) { // TODO(jianhao): get attr directly from attrs, remove the copy of OperatorConf and // UserOpConfWrapper here std::shared_ptr op_conf = std::make_shared(*op_conf_); auto* user_op_conf = op_conf->mutable_user_conf(); for (const auto& it : attrs) { AttrValue attr_val; - it.second->ToProto(&attr_val); + user_op::AttrValueUtil::ToProtoAttrValue(*it.second, &attr_val); (*(user_op_conf->mutable_attr()))[it.first] = attr_val; } *user_op_conf_ = user_op::UserOpConfWrapper(op_conf); diff --git a/oneflow/user/kernels/stateful_local_opkernel.h b/oneflow/user/kernels/stateful_local_opkernel.h index bb83715c3ce..7ff5fc723a7 100644 --- a/oneflow/user/kernels/stateful_local_opkernel.h +++ b/oneflow/user/kernels/stateful_local_opkernel.h @@ -23,7 +23,7 @@ limitations under the License. namespace oneflow { -class AttrValueMap; +class AttrMap; namespace vm { struct LocalCallOpKernelUtil; @@ -272,7 +272,7 @@ class StatefulOpKernel final { UpdateInferContext(nullptr, nullptr); } - void ResetDynamicOpAttrs(const AttrValueMap& attrs); + void ResetDynamicOpAttrs(const AttrMap& attrs); private: friend struct vm::LocalCallOpKernelUtil;