@@ -118,13 +118,44 @@ class OpUpdate : public OpUpdateBase {
118118 OpUpdateType type_;
119119};
120120
121+ template <OpUpdateType type__, typename InfoType>
122+ OpUpdate<InfoType, type__>* new_update (InfoType&& info) {
123+ return new OpUpdate<InfoType, type__>(info);
124+ }
125+
126+ template <typename T>
127+ OpAttrVariantT op_attr_wrapper (const T& val) {
128+ return OpAttrVariantT{val};
129+ }
130+
131+ template <int N>
132+ OpAttrVariantT op_attr_wrapper (const char (&val)[N]) {
133+ PADDLE_ENFORCE_EQ (
134+ val[N - 1 ], 0 ,
135+ platform::errors::InvalidArgument (
136+ " The argument of operator register %c is illegal." , val[N - 1 ]));
137+ return OpAttrVariantT{std::string{val}};
138+ }
139+
121140class OpVersionDesc {
122141 public:
123142 /* Compatibility upgrade */
143+ template <typename T>
124144 OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
125- const OpAttrVariantT& default_value);
145+ const T& default_value) {
146+ infos_.emplace_back (new_update<OpUpdateType::kModifyAttr >(
147+ OpAttrInfo (name, remark, op_attr_wrapper (default_value))));
148+ return std::move (*this );
149+ }
150+
151+ template <typename T>
126152 OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
127- const OpAttrVariantT& default_value);
153+ const T& default_value) {
154+ infos_.emplace_back (new_update<OpUpdateType::kNewAttr >(
155+ OpAttrInfo (name, remark, op_attr_wrapper (default_value))));
156+ return std::move (*this );
157+ }
158+
128159 OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
129160 OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
130161 OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);
0 commit comments