Skip to content

Commit 23b9783

Browse files
MingMingShangTianShixiaowei02
authored andcommitted
add trace op_register_version and fix version bug; test=op_version (PaddlePaddle#30000)
* add trace op_register_version and fix defaulf bug; test=op_version * add trace op_register_version; test=op_version * add trace op_register_version; test=op_version * add trace op_register_version; test=op_version * fix missing the template bug of vector; test=op_version
1 parent 04eb711 commit 23b9783

File tree

1 file changed

+21
-2
lines changed

1 file changed

+21
-2
lines changed

paddle/fluid/operators/trace_op.cc

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/operators/trace_op.h"
16+
#include "paddle/fluid/framework/op_version_registry.h"
1617

1718
namespace paddle {
1819
namespace operators {
@@ -88,13 +89,13 @@ class TraceOpMaker : public framework::OpProtoAndCheckerMaker {
8889
R"DOC((int, default 0), the first axis of the 2-D planes from which the diagonals should be taken.
8990
Can be either positive or negative. Default: 0.
9091
)DOC")
91-
.SetDefault(-2);
92+
.SetDefault(0);
9293
AddAttr<int>(
9394
"axis2",
9495
R"DOC((int, default 1), the second axis of the 2-D planes from which the diagonals should be taken.
9596
Can be either positive or negative. Default: 1.
9697
)DOC")
97-
.SetDefault(-1);
98+
.SetDefault(1);
9899
AddComment(R"DOC(
99100
Trace Operator.
100101
Return the sum along diagonals of the input tensor.
@@ -177,3 +178,21 @@ REGISTER_OP_CPU_KERNEL(
177178
paddle::platform::complex64>,
178179
ops::TraceGradKernel<paddle::platform::CPUDeviceContext,
179180
paddle::platform::complex128>);
181+
182+
/* ========================== register checkpoint ===========================*/
183+
REGISTER_OP_VERSION(trace)
184+
.AddCheckpoint(
185+
R"ROC(Upgrade trace add a new attribute [axis2])ROC",
186+
paddle::framework::compatible::OpVersionDesc()
187+
.NewAttr("axis1",
188+
"The added attribute 'axis1' is not yet registered.",
189+
std::vector<float>{0.0f})
190+
.NewAttr("axis2",
191+
"The added attribute 'axis2' is not yet registered.",
192+
std::vector<float>{1.0f})
193+
.DeleteAttr("dim1",
194+
"The attribute 'dim1' is not recommend according to "
195+
"the specification 2.0.")
196+
.DeleteAttr("dim2",
197+
"The attribute 'dim2' is not recommend according to "
198+
"the specification 2.0."));

0 commit comments

Comments
 (0)