@@ -48,29 +48,6 @@ namespace framework {
48
48
49
49
using Tensor = framework::Tensor;
50
50
51
- template <typename ClassType>
52
- void ExposeOperator (ClassType &m) {
53
- m.def (" infer_shape" , &ClassType::type::InferShape)
54
- .def (" run" , &ClassType::type::Run)
55
- .def (" type" ,
56
- [](const typename ClassType::type &op) -> std::string {
57
- return op.Type ();
58
- })
59
- .def (" outputs" ,
60
- [](const typename ClassType::type &op)
61
- -> std::map<std::string, std::vector<std::string>> {
62
- return op.Outputs ();
63
- })
64
- .def (" inputs" ,
65
- [](const typename ClassType::type &op) { return op.Inputs (); })
66
- .def (" __str__" , &ClassType::type::DebugString)
67
- .def (" no_intermediate_outputs" ,
68
- [](const typename ClassType::type &op) {
69
- return op.OutputVars (false );
70
- })
71
- .def (" support_gpu" , &ClassType::type::SupportGPU);
72
- }
73
-
74
51
static size_t UniqueIntegerGenerator () {
75
52
static std::atomic<size_t > generator;
76
53
return generator.fetch_add (1 );
@@ -207,75 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
207
184
.def (py::init<>())
208
185
.def (" __str__" , string::to_string<const platform::CPUPlace &>);
209
186
210
- py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base (
211
- m, " Operator" );
212
-
213
- operator_base.def_static (" create" , [](py::bytes protobin) {
214
- OpDesc desc;
215
- PADDLE_ENFORCE (desc.ParsePartialFromString (protobin),
216
- " Cannot parse user input to OpDesc" );
217
- PADDLE_ENFORCE (desc.IsInitialized (),
218
- " User OpDesc is not initialized, reason %s" ,
219
- desc.InitializationErrorString ());
220
- return OpRegistry::CreateOp (desc);
221
- });
222
-
223
- operator_base.def (" backward" ,
224
- [](const OperatorBase &forwardOp,
225
- const std::unordered_set<std::string> &no_grad_vars) {
226
- return Backward (forwardOp, no_grad_vars);
227
- });
228
-
229
- ExposeOperator (operator_base);
230
-
231
- py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net (m, " Net" );
232
-
233
- net.def_static (" create" ,
234
- []() -> std::shared_ptr<operators::NetOp> {
235
- auto retv = std::make_shared<operators::NetOp>();
236
- retv->SetType (" plain_net" );
237
- return retv;
238
- })
239
- .def (" add_op" , &operators::NetOp::AddOp)
240
- .def (" add_op" ,
241
- [](operators::NetOp &self,
242
- const std::shared_ptr<operators::NetOp> &net) -> void {
243
- self.AddOp (std::static_pointer_cast<OperatorBase>(net));
244
- })
245
- .def (" add_op" ,
246
- [](operators::NetOp &self,
247
- const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
248
- self.AddOp (std::static_pointer_cast<OperatorBase>(rnn));
187
+ py::class_<OperatorBase>(m, " Operator" )
188
+ .def_static (" create" ,
189
+ [](py::bytes protobin) {
190
+ OpDesc desc;
191
+ PADDLE_ENFORCE (desc.ParsePartialFromString (protobin),
192
+ " Cannot parse user input to OpDesc" );
193
+ PADDLE_ENFORCE (desc.IsInitialized (),
194
+ " User OpDesc is not initialized, reason %s" ,
195
+ desc.InitializationErrorString ());
196
+ return OpRegistry::CreateOp (desc);
197
+ })
198
+ .def (" backward" ,
199
+ [](const OperatorBase &forwardOp,
200
+ const std::unordered_set<std::string> &no_grad_vars) {
201
+ return Backward (forwardOp, no_grad_vars).release ();
249
202
})
203
+ .def (" infer_shape" , &OperatorBase::InferShape)
204
+ .def (" run" , &OperatorBase::Run)
205
+ .def (" type" ,
206
+ [](const OperatorBase &op) -> std::string { return op.Type (); })
207
+ .def (" outputs" ,
208
+ [](const OperatorBase &op)
209
+ -> std::map<std::string, std::vector<std::string>> {
210
+ return op.Outputs ();
211
+ })
212
+ .def (" inputs" , [](const OperatorBase &op) { return op.Inputs (); })
213
+ .def (" __str__" , &OperatorBase::DebugString)
214
+ .def (" no_intermediate_outputs" ,
215
+ [](const OperatorBase &op) { return op.OutputVars (false ); })
216
+ .def (" support_gpu" , &OperatorBase::SupportGPU);
217
+
218
+ py::class_<operators::NetOp, OperatorBase>(m, " Net" )
219
+ .def_static (" create" ,
220
+ []() -> operators::NetOp * {
221
+ auto *retv = new operators::NetOp;
222
+ retv->SetType (" plain_net" );
223
+ return retv;
224
+ })
225
+ .def (" add_op" , [](operators::NetOp &self,
226
+ const OperatorBase &op) { self.AddOp (op); })
250
227
.def (" complete_add_op" , &operators::NetOp::CompleteAddOp)
251
228
.def (" complete_add_op" , [](std::shared_ptr<operators::NetOp> &self) {
252
229
self->CompleteAddOp ();
253
230
});
254
231
255
- ExposeOperator (net);
256
-
257
232
// recurrent_op
258
- py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
259
- rnn (m, " RecurrentOp" );
260
-
261
- rnn.def_static (
262
- " create" ,
263
- [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
264
- OpDesc desc;
265
- PADDLE_ENFORCE (desc.ParsePartialFromString (protobin),
266
- " Cannot parse user input to OpDesc" );
267
- PADDLE_ENFORCE (desc.IsInitialized (),
268
- " User OpDesc is not initialized, reason %s" ,
269
- desc.InitializationErrorString ());
270
- auto rnn_op = OpRegistry::CreateOp (desc);
271
- return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
272
- })
273
- .def (" set_stepnet" ,
274
- [](operators::RecurrentOp &self,
275
- const std::shared_ptr<operators::NetOp> &net) -> void {
276
- self.set_stepnet (net);
277
- });
278
- ExposeOperator (rnn);
233
+ py::class_<operators::RecurrentOp, OperatorBase>(m, " RecurrentOp" )
234
+ .def_static (
235
+ " create" ,
236
+ [](py::bytes protobin) -> operators::RecurrentOp * {
237
+ OpDesc desc;
238
+ PADDLE_ENFORCE (desc.ParsePartialFromString (protobin),
239
+ " Cannot parse user input to OpDesc" );
240
+ PADDLE_ENFORCE (desc.IsInitialized (),
241
+ " User OpDesc is not initialized, reason %s" ,
242
+ desc.InitializationErrorString ());
243
+ auto rnn_op = OpRegistry::CreateOp (desc);
244
+ return static_cast <operators::RecurrentOp *>(rnn_op.release ());
245
+ })
246
+ .def (" set_stepnet" , [](operators::RecurrentOp &self,
247
+ const operators::NetOp &net) -> void {
248
+ self.set_stepnet (net.Clone ());
249
+ });
279
250
280
251
m.def (" unique_integer" , UniqueIntegerGenerator);
281
252
0 commit comments