Dev refactor attr value map (#4755)
* Refactor

* Draft

* Refactor

* Add MutableCfgAttrValueMap

* implement AttrValueMap and ComposedAttrValueMap (#4767)

* Attr value util (#4773)

* implement AttrValueMap and ComposedAttrValueMap

* AttrValueUtil::ToProtoAttrValue

* Fix compile

* Fix compilation

* Rename AttrValueMap to AttrMap.

Co-authored-by: lixinqi <lixinqi0703106@163.com>
Co-authored-by: hjchen2 <hjchen2>
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Former-commit-id: 354652e
4 people authored Apr 29, 2021
1 parent ba18214 commit 4bc260a
Showing 37 changed files with 436 additions and 233 deletions.
@@ -17,29 +17,28 @@ limitations under the License.
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"

#include "oneflow/core/framework/attr_value_map.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/user_op_attr.cfg.h"

namespace py = pybind11;

namespace oneflow {

ONEFLOW_API_PYBIND11_MODULE("", m) {
-  py::class_<AttrValueMap, std::shared_ptr<AttrValueMap>>(m, "AttrValueMap")
+  py::class_<MutableCfgAttrMap, std::shared_ptr<MutableCfgAttrMap>>(m, "MutableCfgAttrMap")
      .def(py::init<>())
      .def("__setitem__",
-          [](AttrValueMap* m, const std::string& attr_name,
+          [](MutableCfgAttrMap* m, const std::string& attr_name,
             const std::shared_ptr<cfg::AttrValue>& attr_value) {
            m->SetAttr(attr_name, attr_value).GetOrThrow();
          })
      .def("__getitem__",
-           [](const AttrValueMap& m, const std::string& attr_name) {
-             m.GetAttr<cfg::AttrValue>(attr_name).GetOrThrow();
-           })
+           [](const MutableCfgAttrMap& m, const std::string& attr_name) { return m.at(attr_name); })
      .def(
-          "__iter__", [](const AttrValueMap& m) { return py::make_iterator(m.begin(), m.end()); },
+          "__iter__",
+          [](const MutableCfgAttrMap& m) { return py::make_iterator(m.begin(), m.end()); },
          py::keep_alive<0, 1>())
-      .def("__len__", [](const AttrValueMap& m) { return m.size(); });
+      .def("__len__", [](const MutableCfgAttrMap& m) { return m.size(); });
}

} // namespace oneflow
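
For context, a hypothetical C++ usage sketch of the interface this binding exposes (not part of the commit; the cfg::AttrValue setter name is assumed from user_op_attr.proto's cfg-generated API, and map-style iteration over name/value pairs is inferred from the __iter__ lambda):

// Sketch only: exercises the same members the pybind lambdas above call.
#include <iostream>
#include <memory>
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/user_op_attr.cfg.h"

void MutableCfgAttrMapDemo() {
  using namespace oneflow;
  MutableCfgAttrMap attrs;
  auto axis = std::make_shared<cfg::AttrValue>();
  axis->set_at_int32(0);                      // assumed cfg-generated setter
  attrs.SetAttr("axis", axis).GetOrThrow();   // what __setitem__ calls
  std::cout << "len = " << attrs.size() << "\n";  // backs __len__
  for (const auto& pair : attrs) {            // begin()/end() back __iter__
    std::cout << pair.first << "\n";
  }
  const auto& value = attrs.at("axis");       // what __getitem__ calls
  (void)value;
}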
10 changes: 5 additions & 5 deletions oneflow/api/python/framework/op_expr.cpp
@@ -17,7 +17,7 @@ limitations under the License.
#include <pybind11/stl.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/attr_value_map.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
@@ -32,7 +32,7 @@ namespace oneflow {
namespace {

Maybe<one::TensorTuple> Interpret(const one::OpExpr& op, const one::TensorTuple& inputs,
-                                  const AttrValueMap& attrs) {
+                                  const AttrMap& attrs) {
CHECK_EQ_OR_RETURN(op.input_size(), inputs.size())
<< "The operation requires " << op.input_size() << " inputs, but " << inputs.size()
<< " is given.";
@@ -44,7 +44,7 @@ Maybe<one::TensorTuple> Interpret(const one::OpExpr& op,

Maybe<one::TensorTuple> Interpret(const one::OpExpr& op,
const std::vector<std::shared_ptr<one::Tensor>>& inputs,
-                                  const AttrValueMap& attrs) {
+                                  const AttrMap& attrs) {
one::TensorTuple input_list(inputs.size());
for (int i = 0; i < inputs.size(); ++i) { input_list[i] = inputs[i]; }
return JUST(Interpret(op, input_list, attrs));
@@ -76,11 +76,11 @@ ONEFLOW_API_PYBIND11_MODULE("one", m) {
.def_property_readonly("output_size", &one::OpExpr::output_size)
.def("apply",
[](const one::OpExpr& op_expr, const std::vector<std::shared_ptr<one::Tensor>>& inputs,
-             const AttrValueMap& attrs) {
+             const MutableCfgAttrMap& attrs) {
return Interpret(op_expr, inputs, attrs).GetPtrOrThrow();
})
.def("apply", [](const one::OpExpr& op_expr, const one::TensorTuple& inputs,
-                       const AttrValueMap& attrs) {
+                       const MutableCfgAttrMap& attrs) {
return Interpret(op_expr, inputs, attrs).GetPtrOrThrow();
});

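
A minimal sketch of the path the "apply" binding takes (illustrative only: the function name is hypothetical, the MutableCfgAttrMap-to-AttrMap conversion is implied by the lambda signatures above, and the TensorTuple instantiation of Dispatch is an assumption):

Maybe<one::TensorTuple> ApplyOpSketch(const one::OpExpr& op,
                                      const std::vector<std::shared_ptr<one::Tensor>>& inputs,
                                      const MutableCfgAttrMap& attrs) {
  CHECK_EQ_OR_RETURN(op.input_size(), inputs.size());  // same guard as Interpret() above
  one::TensorTuple input_list(inputs.size());          // pack the vector into a TensorTuple
  for (int i = 0; i < inputs.size(); ++i) { input_list[i] = inputs[i]; }
  // Hand the inputs and the (now read-only) attrs to the interpreter.
  return one::OpInterpUtil::Dispatch<one::TensorTuple>(op, input_list, attrs);
}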
6 changes: 3 additions & 3 deletions oneflow/core/autograd/gradient_funcs/batch_gather.cpp
@@ -31,7 +31,7 @@ class BatchGather : public OpExprGradFunction<BatchGatherInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const BatchGatherInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

@@ -49,7 +49,7 @@ Maybe<void> BatchGather::Init(const OpExpr& op) {
}

Maybe<void> BatchGather::Capture(BatchGatherInterpState* ctx, const TensorTuple& inputs,
-                                 const TensorTuple& outputs, const AttrValueMap& attrs) const {
+                                 const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& in_shape = inputs.at(0)->shape();
@@ -64,7 +64,7 @@ Maybe<void> BatchGather::Apply(const BatchGatherInterpState* ctx, const TensorTu
in_grads->resize(2);
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
const auto& indices = ctx->SavedTensors().at(0);
-  AttrValueMap attrs;
+  MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("num_segments", ctx->num_segments));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*bw_unsorted_batch_segment_sum_op_,
{out_grads.at(0), indices}, attrs));
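
The Apply() above is an instance of the pattern this refactor standardizes on: gradient functions receive a read-only AttrMap in Capture() and build a fresh MutableAttrMap for each backward dispatch. A condensed sketch (function and parameter names are illustrative, not from the commit):

Maybe<void> DispatchBackwardSketch(const std::shared_ptr<OpExpr>& bw_op,
                                   const std::shared_ptr<Tensor>& dy,
                                   const std::shared_ptr<Tensor>& indices,
                                   int32_t num_segments, TensorTuple* in_grads) {
  MutableAttrMap attrs;                                        // writable builder
  JUST(attrs.SetAttr<int32_t>("num_segments", num_segments));  // typed, checked setter
  in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*bw_op, {dy, indices}, attrs));
  return Maybe<void>::Ok();
}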
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/bias_add.cpp
@@ -43,7 +43,7 @@ class BiasAdd : public OpExprGradFunction<BiasAddInterpState> {
}

Maybe<void> Capture(BiasAddInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override {
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
ctx->input_requires_grad = inputs.at(0)->requires_grad();
ctx->bias_requires_grad = inputs.at(1)->requires_grad();
@@ -60,7 +60,7 @@ class BiasAdd : public OpExprGradFunction<BiasAddInterpState> {
for (int i = 0; i < num_axes; ++i) {
if (i != axis_) { reduce_axes_vec.push_back(i); }
}
-    AttrValueMap attrs;
+    MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", reduce_axes_vec));
in_grads->at(1) =
JUST(OpInterpUtil::Dispatch<Tensor>(*backward_bias_op_, {out_grads.at(0)}, attrs));
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/broadcast_binary_ops.cpp
@@ -38,7 +38,7 @@ class ReduceSumLikeModule {
const auto& in_shape = *(input->shape());
const auto& like_shape = *(like->shape());
TensorTuple inputs{input};
-    AttrValueMap attrs;
+    MutableAttrMap attrs;
std::shared_ptr<OpExpr> op = identity_op_;
if (in_shape != like_shape) {
const Shape& left_extended_shape =
@@ -77,7 +77,7 @@ class BroadcastBinaryGrad : public OpExprGradFunction<OpExprInterpState> {
}

Maybe<void> Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrValueMap& attrs) const override {
+                      const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 2);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->SaveTensorForBackward(inputs.at(0));
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/cast.cpp
@@ -35,7 +35,7 @@ class Cast : public OpExprGradFunction<OpExprInterpState> {
}

Maybe<void> Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrValueMap& attrs) const override {
+                      const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
@@ -44,7 +44,7 @@ class Cast : public OpExprGradFunction<OpExprInterpState> {
TensorTuple* in_grads) const override {
const auto& x = ctx->SavedTensors().at(0);
in_grads->resize(1);
-    AttrValueMap attrs;
+    MutableAttrMap attrs;
JUST(attrs.SetAttr<DataType>("dtype", x->dtype()->data_type()));
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*grad_op_, {out_grads.at(0)}, attrs));
return Maybe<void>::Ok();
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/deconv.cpp
@@ -32,7 +32,7 @@ class DeConvolutionNd : public OpExprGradFunction<DeConvolutionNdInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(DeConvolutionNdInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const DeConvolutionNdInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

@@ -72,7 +72,7 @@ Maybe<void> DeConvolutionNd::Init(const OpExpr& op) {
}

Maybe<void> DeConvolutionNd::Capture(DeConvolutionNdInterpState* ctx, const TensorTuple& inputs,
-                                     const TensorTuple& outputs, const AttrValueMap& attrs) const {
+                                     const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->activation_requires_grad = inputs.at(0)->requires_grad();
ctx->weight_requires_grad = inputs.at(1)->requires_grad();
if (ctx->activation_requires_grad) {
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/default.cpp
@@ -44,7 +44,7 @@ class DefaultOpExprGradFunction : public OpExprGradFunction<DefaultOpExprInterpS
Maybe<void> Init(const OpExpr& op) override;

Maybe<void> Capture(DefaultOpExprInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;

Maybe<void> Apply(const DefaultOpExprInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
@@ -285,7 +285,7 @@ Maybe<void> DefaultOpExprGradFunction::UpdateRequiresBackward(DefaultOpExprInter
Maybe<void> DefaultOpExprGradFunction::Capture(DefaultOpExprInterpState* ctx,
const TensorTuple& inputs,
const TensorTuple& outputs,
-                                               const AttrValueMap& attrs) const {
+                                               const AttrMap& attrs) const {
CHECK_OR_RETURN(attrs.empty())
<< "The default op expr gradient func does not support dynamic attributes.";
JUST(UpdateRequiresBackward(ctx, inputs));
6 changes: 3 additions & 3 deletions oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -38,7 +38,7 @@ class LayerNorm : public OpExprGradFunction<LayerNormInterpState> {
Maybe<void> Init(const OpExpr& op) override;

Maybe<void> Capture(LayerNormInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;

Maybe<void> Apply(const LayerNormInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;
@@ -68,7 +68,7 @@ Maybe<void> LayerNorm::Init(const OpExpr& op) {
}

Maybe<void> LayerNorm::Capture(LayerNormInterpState* ctx, const TensorTuple& inputs,
-                               const TensorTuple& outputs, const AttrValueMap& attrs) const {
+                               const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), center_ + scale_ + 1);
CHECK_EQ_OR_RETURN(inputs.size(), scale_ + 3);
ctx->has_beta_diff = center_ && inputs.at(1)->requires_grad();
@@ -122,7 +122,7 @@ Maybe<void> LayerNorm::Apply(const LayerNormInterpState* ctx, const TensorTuple&
const auto& x = saved_tensors.at(offset);
const auto& mean = saved_tensors.at(offset + 1);
const auto& inv_variance = saved_tensors.at(offset + 2);
-  AttrValueMap attrs;
+  MutableAttrMap attrs;
JUST(attrs.SetAttr<int64_t>("begin_norm_axis", begin_norm_axis));
in_grads->at(0) =
JUST(OpInterpUtil::Dispatch<Tensor>(*x_grad_op_, {x, mean, inv_variance, dy}, attrs));
4 changes: 2 additions & 2 deletions oneflow/core/autograd/gradient_funcs/normalization.cpp
@@ -67,7 +67,7 @@ class NormalizationGrad : public OpExprGradFunction<NormalizationGradInterpState
}

Maybe<void> Capture(NormalizationGradInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override {
+                      const TensorTuple& outputs, const AttrMap& attrs) const override {
ctx->is_training = JUST(op_trait_->GetAttr<bool>("training", attrs));
ctx->SaveTensorForBackward(inputs.at(0)); // x
ctx->SaveTensorForBackward(inputs.at(3)); // gamma
@@ -118,7 +118,7 @@ class NormalizationGrad : public OpExprGradFunction<NormalizationGradInterpState
dim_vec.push_back(x->shape()->At(axis_));
}
}
-    AttrValueMap shape_attr;
+    MutableAttrMap shape_attr;
shape_attr.SetAttr<Shape>("shape", Shape(dim_vec));
const auto& reshaped_gamma =
JUST(OpInterpUtil::Dispatch<Tensor>(*reshape_gamma_op_, {gamma}, shape_attr));
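
The Capture() above also shows the read side of the new API: attributes arrive as an immutable AttrMap snapshot of the forward call and are read through typed lookups. A condensed sketch (the trait type is left as a template parameter because only the GetAttr call is visible in this diff; the function name is illustrative):

template <typename OpTraitT>
Maybe<void> CaptureTrainingFlagSketch(NormalizationGradInterpState* ctx,
                                      const TensorTuple& inputs, const AttrMap& attrs,
                                      const OpTraitT& op_trait) {
  // Same lookup as above: a typed read from the immutable attr map.
  ctx->is_training = JUST(op_trait->GetAttr<bool>("training", attrs));
  ctx->SaveTensorForBackward(inputs.at(0));  // x
  return Maybe<void>::Ok();
}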
2 changes: 1 addition & 1 deletion oneflow/core/autograd/gradient_funcs/reshape.cpp
@@ -34,7 +34,7 @@ class ReshapeOpExprGrad : public OpExprGradFunction<OpExprInterpState> {
}

Maybe<void> Capture(OpExprInterpState* ctx, const TensorTuple& inputs, const TensorTuple& outputs,
-                      const AttrValueMap& attrs) const override {
+                      const AttrMap& attrs) const override {
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}
6 changes: 3 additions & 3 deletions oneflow/core/autograd/gradient_funcs/split_like.cpp
@@ -32,7 +32,7 @@ class SplitLike : public OpExprGradFunction<SplitLikeInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(SplitLikeInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const SplitLikeInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

@@ -61,7 +61,7 @@ Maybe<void> SplitLike::Init(const OpExpr& op) {
}

Maybe<void> SplitLike::Capture(SplitLikeInterpState* ctx, const TensorTuple& inputs,
-                               const TensorTuple& outputs, const AttrValueMap& attrs) const {
+                               const TensorTuple& outputs, const AttrMap& attrs) const {
CHECK_EQ_OR_RETURN(inputs.size(), outputs.size() + 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
@@ -92,7 +92,7 @@ Maybe<void> SplitLike::Apply(const SplitLikeInterpState* ctx, const TensorTuple&
inputs.push_back(zero_grad);
}
}
-  AttrValueMap concat_attrs;
+  MutableAttrMap concat_attrs;
concat_attrs.SetAttr<int>("axis", axis_);
concat_attrs.SetAttr<int>("max_dim_size", ctx->max_dim_size);
in_grads->at(0) = JUST(OpInterpUtil::Dispatch<Tensor>(*concat_op_, inputs, concat_attrs));
19 changes: 9 additions & 10 deletions oneflow/core/autograd/gradient_funcs/tensor_scalar_binary.cpp
@@ -34,7 +34,7 @@ class TensorScalarAddOrSub : public OpExprGradFunction<TensorScalarInterpState>

Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;

protected:
std::shared_ptr<OpExpr> identity_op_;
@@ -56,8 +56,7 @@ Maybe<void> TensorScalarAddOrSub::Init(const OpExpr& op) {
}

Maybe<void> TensorScalarAddOrSub::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
-                                          const TensorTuple& outputs,
-                                          const AttrValueMap& attrs) const {
+                                          const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
return Maybe<void>::Ok();
@@ -75,7 +74,7 @@ class TensorScalarAdd : public TensorScalarAddOrSub {
int32_t num_axes = out_grads.at(0)->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
-    AttrValueMap attrs;
+    MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axes_vec));
in_grads->at(1) =
JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {out_grads.at(0)}, attrs));
@@ -97,7 +96,7 @@ class TensorScalarSub : public TensorScalarAddOrSub {
int32_t num_axes = out_grads.at(0)->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
-    AttrValueMap attrs;
+    MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axes_vec));
const auto& reduce_sum =
JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {out_grads.at(0)}, attrs));
@@ -114,7 +113,7 @@ class TensorScalarMul : public OpExprGradFunction<TensorScalarInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

@@ -136,7 +135,7 @@ Maybe<void> TensorScalarMul::Init(const OpExpr& op) {
}

Maybe<void> TensorScalarMul::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
-                                     const TensorTuple& outputs, const AttrValueMap& attrs) const {
+                                     const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
if (ctx->x_requires_grad) { ctx->SaveTensorForBackward(inputs.at(1)); }
@@ -158,7 +157,7 @@ Maybe<void> TensorScalarMul::Apply(const TensorScalarInterpState* ctx, const Ten
int32_t num_axes = out_grads.at(0)->shape()->NumAxes();
std::vector<int32_t> axes_vec(num_axes);
std::iota(axes_vec.begin(), axes_vec.end(), 0);
-    AttrValueMap attrs;
+    MutableAttrMap attrs;
JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axes_vec));
in_grads->at(1) = JUST(OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op_, {y}, attrs));
}
@@ -171,7 +170,7 @@ class TensorScalarDiv : public OpExprGradFunction<TensorScalarInterpState> {
public:
Maybe<void> Init(const OpExpr& op) override;
Maybe<void> Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
-                      const TensorTuple& outputs, const AttrValueMap& attrs) const override;
+                      const TensorTuple& outputs, const AttrMap& attrs) const override;
Maybe<void> Apply(const TensorScalarInterpState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override;

@@ -190,7 +189,7 @@ Maybe<void> TensorScalarDiv::Init(const OpExpr& op) {
}

Maybe<void> TensorScalarDiv::Capture(TensorScalarInterpState* ctx, const TensorTuple& inputs,
-                                     const TensorTuple& outputs, const AttrValueMap& attrs) const {
+                                     const TensorTuple& outputs, const AttrMap& attrs) const {
ctx->x_requires_grad = inputs.at(0)->requires_grad();
ctx->scalar_requires_grad = inputs.at(1)->requires_grad();
if (ctx->x_requires_grad || ctx->scalar_requires_grad) {
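
All four scalar-grad classes above reduce the incoming gradient over every axis before producing the scalar's gradient. The idiom, factored out as a sketch (name and signature are illustrative; the body mirrors the code above):

Maybe<Tensor> ReduceOverAllAxesSketch(const std::shared_ptr<OpExpr>& reduce_sum_op,
                                      const std::shared_ptr<Tensor>& t) {
  int32_t num_axes = t->shape()->NumAxes();
  std::vector<int32_t> axes_vec(num_axes);
  std::iota(axes_vec.begin(), axes_vec.end(), 0);  // {0, 1, ..., num_axes - 1}
  MutableAttrMap attrs;
  JUST(attrs.SetAttr<std::vector<int32_t>>("axis", axes_vec));
  return OpInterpUtil::Dispatch<Tensor>(*reduce_sum_op, {t}, attrs);
}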
6 changes: 6 additions & 0 deletions oneflow/core/common/shape.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/shape.cfg.h"
#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/common/protobuf.h"

@@ -58,6 +59,11 @@ Shape::Shape(const ShapeProto& shape_proto) {
UpdateElemCnt();
}

+Shape::Shape(const cfg::ShapeProto& shape_proto) {
+  dim_vec_.assign(shape_proto.dim().begin(), shape_proto.dim().end());
+  UpdateElemCnt();
+}

Shape& Shape::operator=(const Shape& shape) {
dim_vec_ = shape.dim_vec_;
UpdateElemCnt();
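
A hypothetical round trip through the new constructor (values are made up; add_dim() is assumed from the cfg-generated, protobuf-mirroring API):

#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/shape.cfg.h"

void ShapeFromCfgDemo() {
  oneflow::cfg::ShapeProto proto;
  proto.add_dim(2);             // assumed protobuf-style mutator
  proto.add_dim(3);
  oneflow::Shape shape(proto);  // the constructor added above
  // UpdateElemCnt() has run inside the constructor, so shape.elem_cnt() == 6.
}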