Skip to content

Commit b2147d6

Browse files
committed
fix a bug in op_version_registry, test=develop, test=op_version (PaddlePaddle#29994)
1 parent 23b9783 commit b2147d6

File tree

5 files changed

+39
-31
lines changed

5 files changed

+39
-31
lines changed

paddle/fluid/framework/op_version_registry.cc

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,6 @@ namespace paddle {
1818
namespace framework {
1919
namespace compatible {
2020

21-
namespace {
22-
template <OpUpdateType type__, typename InfoType>
23-
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
24-
return new OpUpdate<InfoType, type__>(info);
25-
}
26-
}
27-
28-
OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
29-
const std::string& remark,
30-
const OpAttrVariantT& default_value) {
31-
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
32-
OpAttrInfo(name, remark, default_value)));
33-
return std::move(*this);
34-
}
35-
36-
OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
37-
const std::string& remark,
38-
const OpAttrVariantT& default_value) {
39-
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
40-
OpAttrInfo(name, remark, default_value)));
41-
return std::move(*this);
42-
}
43-
4421
OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
4522
const std::string& remark) {
4623
infos_.emplace_back(

paddle/fluid/framework/op_version_registry.h

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,44 @@ class OpUpdate : public OpUpdateBase {
118118
OpUpdateType type_;
119119
};
120120

121+
template <OpUpdateType type__, typename InfoType>
122+
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
123+
return new OpUpdate<InfoType, type__>(info);
124+
}
125+
126+
template <typename T>
127+
OpAttrVariantT op_attr_wrapper(const T& val) {
128+
return OpAttrVariantT{val};
129+
}
130+
131+
template <int N>
132+
OpAttrVariantT op_attr_wrapper(const char (&val)[N]) {
133+
PADDLE_ENFORCE_EQ(
134+
val[N - 1], 0,
135+
platform::errors::InvalidArgument(
136+
"The argument of operator register %c is illegal.", val[N - 1]));
137+
return OpAttrVariantT{std::string{val}};
138+
}
139+
121140
class OpVersionDesc {
122141
public:
123142
/* Compatibility upgrade */
143+
template <typename T>
124144
OpVersionDesc&& ModifyAttr(const std::string& name, const std::string& remark,
125-
const OpAttrVariantT& default_value);
145+
const T& default_value) {
146+
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
147+
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
148+
return std::move(*this);
149+
}
150+
151+
template <typename T>
126152
OpVersionDesc&& NewAttr(const std::string& name, const std::string& remark,
127-
const OpAttrVariantT& default_value);
153+
const T& default_value) {
154+
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
155+
OpAttrInfo(name, remark, op_attr_wrapper(default_value))));
156+
return std::move(*this);
157+
}
158+
128159
OpVersionDesc&& NewInput(const std::string& name, const std::string& remark);
129160
OpVersionDesc&& NewOutput(const std::string& name, const std::string& remark);
130161
OpVersionDesc&& BugfixWithBehaviorChanged(const std::string& remark);

paddle/fluid/operators/conv_transpose_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ REGISTER_OP_VERSION(conv_transpose)
578578
"output_padding",
579579
"In order to add additional size to one side of each dimension "
580580
"in the output",
581-
{}));
581+
std::vector<int>{}));
582582

583583
REGISTER_OP_VERSION(conv2d_transpose)
584584
.AddCheckpoint(
@@ -589,7 +589,7 @@ REGISTER_OP_VERSION(conv2d_transpose)
589589
"output_padding",
590590
"In order to add additional size to one side of each dimension "
591591
"in the output",
592-
{}));
592+
std::vector<int>{}));
593593

594594
REGISTER_OP_VERSION(conv3d_transpose)
595595
.AddCheckpoint(
@@ -600,7 +600,7 @@ REGISTER_OP_VERSION(conv3d_transpose)
600600
"output_padding",
601601
"In order to add additional size to one side of each dimension "
602602
"in the output",
603-
{}));
603+
std::vector<int>{}));
604604

605605
REGISTER_OP_VERSION(depthwise_conv2d_transpose)
606606
.AddCheckpoint(
@@ -611,4 +611,4 @@ REGISTER_OP_VERSION(depthwise_conv2d_transpose)
611611
"output_padding",
612612
"In order to add additional size to one side of each dimension "
613613
"in the output",
614-
{}));
614+
std::vector<int>{}));

paddle/fluid/operators/fused/fusion_gru_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,4 +489,4 @@ REGISTER_OP_VERSION(fusion_gru)
489489
"Scale_weights",
490490
"The added attribute 'Scale_weights' is not yet "
491491
"registered.",
492-
{1.0f}));
492+
std::vector<float>{1.0f}));

paddle/fluid/operators/unique_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ REGISTER_OP_VERSION(unique)
184184
.NewAttr("axis",
185185
"The axis to apply unique. If None, the input will be "
186186
"flattened.",
187-
{})
187+
std::vector<int>{})
188188
.NewAttr("is_sorted",
189189
"If True, the unique elements of X are in ascending order."
190190
"Otherwise, the unique elements are not sorted.",

0 commit comments

Comments
 (0)