Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Attr value util #4773

Merged
merged 3 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions oneflow/core/common/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/shape.cfg.h"
#include "oneflow/core/common/shape_view.h"
#include "oneflow/core/common/protobuf.h"

Expand Down Expand Up @@ -58,6 +59,11 @@ Shape::Shape(const ShapeProto& shape_proto) {
UpdateElemCnt();
}

Shape::Shape(const cfg::ShapeProto& shape_proto) {
dim_vec_.assign(shape_proto.dim().begin(), shape_proto.dim().end());
UpdateElemCnt();
}

Shape& Shape::operator=(const Shape& shape) {
dim_vec_ = shape.dim_vec_;
UpdateElemCnt();
Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/common/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ namespace oneflow {

class ShapeView;

namespace cfg {
class ShapeProto;
}

class Shape final {
public:
// OF_DISALLOW_COPY_AND_MOVE(Shape);
Shape() : elem_cnt_(0) {}
explicit Shape(const DimVector& dim_vec);
explicit Shape(DimVector&& dim_vec);
explicit Shape(const ShapeProto& shape_proto);
explicit Shape(const cfg::ShapeProto& shape_proto);
Shape(const std::initializer_list<int64_t>& dim_vec);
~Shape() = default;
Shape& operator=(const Shape& shape);
Expand Down
53 changes: 51 additions & 2 deletions oneflow/core/framework/attr_value_accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/user_op_conf.h"

namespace oneflow {

Expand All @@ -29,6 +30,10 @@ namespace user_op {
CHECK(val.has_##field()); \
return val.field(); \
} \
cpp_type AttrValueAccessor<cpp_type>::Attr(const cfg::AttrValue& val) { \
CHECK(val.has_##field()); \
return val.field(); \
} \
template<> \
void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \
attr_val->set_##field(cpp_val); \
Expand All @@ -48,6 +53,12 @@ template<>
Shape AttrValueAccessor<Shape>::Attr(const AttrValue& val) {
return Shape(val.at_shape());
}

template<>
Shape AttrValueAccessor<Shape>::Attr(const cfg::AttrValue& val) {
return Shape(val.at_shape());
}

template<>
void AttrValueAccessor<Shape>::Attr(const Shape& cpp_val, AttrValue* attr_val) {
cpp_val.ToProto(attr_val->mutable_at_shape());
Expand All @@ -59,6 +70,10 @@ void AttrValueAccessor<Shape>::Attr(const Shape& cpp_val, AttrValue* attr_val) {
cpp_type AttrValueAccessor<cpp_type>::Attr(const AttrValue& val) { \
return PbRf2StdVec<cpp_type::value_type>(val.field().val()); \
} \
cpp_type AttrValueAccessor<cpp_type>::Attr(const cfg::AttrValue& val) { \
const auto& rp_val = val.field().val(); \
return cpp_type(rp_val.begin(), rp_val.end()); \
} \
template<> \
void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \
*(attr_val->mutable_##field()->mutable_val()) = StdVec2PbRf<cpp_type::value_type>(cpp_val); \
Expand All @@ -80,6 +95,15 @@ OF_PP_FOR_EACH_TUPLE(LIST_BASIC_ATTR_SEQ_ENTRY, LIST_BASIC_ATTR_SEQ)
return ret; \
} \
template<> \
cpp_type AttrValueAccessor<cpp_type>::Attr(const cfg::AttrValue& val) { \
std::vector<cpp_type::value_type> ret; \
ret.reserve(val.field().val_size()); \
for (const auto& value : val.field().val()) { \
ret.emplace_back(static_cast<cpp_type::value_type>(value)); \
} \
return ret; \
} \
template<> \
void AttrValueAccessor<cpp_type>::Attr(const cpp_type& cpp_val, AttrValue* attr_val) { \
using proto_type = std::remove_reference_t<decltype(attr_val->field().val())>::value_type; \
std::vector<proto_type> vec; \
Expand Down Expand Up @@ -120,17 +144,42 @@ void AttrValueAccessor<std::vector<std::string>>::Attr(const std::vector<std::st
*(attr_val->mutable_at_list_string()->mutable_val()) = StdVec2PbRpf<std::string>(cpp_val);
}

Maybe<AttrVal> MakeCppAttrValByCfgAttrValue(const cfg::AttrValue& cfg_attr_value) {
template<typename ProtoT>
Maybe<AttrVal> MakeCppAttrValueFromProtoOrCfgAttrValue(const ProtoT& cfg_attr_value) {
switch (static_cast<int>(cfg_attr_value.value_case())) {
#define MAKE_ENTRY(field, T, attr_type) \
case static_cast<int>(attr_type): return AttrValueAccessor<T>(cfg_attr_value);
case static_cast<int>(attr_type): return AttrValueAccessor<T>::Attr(cfg_attr_value);
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ);
#undef MAKE_ENTRY
default:
OF_UNIMPLEMENTED();
}
}

/*static*/Maybe<AttrVal> AttrValueUtil::ToCppAttrValue(const AttrValue& proto_attr_value) {
return MakeCppAttrValueFromProtoOrCfgAttrValue(proto_attr_value);
}

/*static*/Maybe<AttrVal> AttrValueUtil::ToCppAttrValue(const cfg::AttrValue& cfg_attr_value) {
return MakeCppAttrValueFromProtoOrCfgAttrValue(proto_attr_value);
}

/*static*/Maybe<AttrVal> AttrValueUtil::ToProtoAttrValue(
const AttrVal& cpp_attr_value, AttrValue* attr_value) {
if (false) {
// clang-format off
#define MAKE_ENTRY(field, cpp_type, attr_type) \
} else if (dynamic_cast<const TypedAttrVal<cpp_type>*>(&cpp_attr_value) != nullptr) { \
const auto* ptr = dynamic_cast<const TypedAttrVal<cpp_type>*>(&cpp_attr_value); \
AttrValueAccessor<cpp_type>::Attr(ptr->val(), attr_value);
OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ);
#undef MAKE_ENTRY
// clang-format on
} else {
OF_UNIMPLEMENTED();
}
}

} // namespace user_op

} // namespace oneflow
7 changes: 6 additions & 1 deletion oneflow/core/framework/attr_value_accessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ namespace user_op {
template<typename T>
struct AttrValueAccessor final {
static T Attr(const AttrValue&);
static T Attr(const cfg::AttrValue&);
static void Attr(const T&, AttrValue*);
};

Expand All @@ -34,7 +35,11 @@ namespace cfg {
class AttrValue;
}

Maybe<AttrVal> MakeCppAttrValByCfgAttrValue(const cfg::AttrValue&);
struct AttrValueUtil final {
static Maybe<AttrVal> ToCppAttrValue(const AttrValue& proto_attr_value);
static Maybe<AttrVal> ToCppAttrValue(const cfg::AttrValue& cfg_attr_value);
static Maybe<void> ToProtoAttrValue(const AttrVal& cpp_attr_value, AttrValue* attr_value);
};

} // namespace user_op

Expand Down
2 changes: 1 addition & 1 deletion oneflow/core/framework/attr_value_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ AttrValueMap::AttrValueMap(const MutableAttrValueMap& other) {

AttrValueMap::AttrValueMap(const MutableCfgAttrValueMap& other) {
for (const auto& pair : other) {
attrs_[pair.first] = CHECK_JUST(user_op::MakeCppAttrValByCfgAttrValue(*pair.second));
attrs_[pair.first] = CHECK_JUST(AttrValueUtil::ToCppAttrValue(*pair.second));
}
}

Expand Down