Commit 1808f88

Merge pull request PaddlePaddle#3538 from reyoung/feature/remove_shared_ptr
Feature/remove shared ptr
2 parents 812a64c + 16d0215 commit 1808f88

11 files changed: +131 -160 lines changed

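The diffs below all apply one pattern: operator factories and the backward builder now return std::unique_ptr<OperatorBase> instead of std::shared_ptr<OperatorBase>, so each operator has a single explicit owner and ownership is handed over with std::move. A minimal sketch of that ownership model, using hypothetical stand-in types rather than the real Paddle classes:

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-ins for OperatorBase / NetOp, only to illustrate the ownership flow.
struct Op {
  virtual ~Op() = default;
  virtual std::string Type() const { return "op"; }
};

struct Net : Op {
  std::string Type() const override { return "net"; }
  std::vector<std::unique_ptr<Op>> ops_;  // the net owns its operators exclusively
  void AddOp(std::unique_ptr<Op> op) { ops_.push_back(std::move(op)); }
};

std::unique_ptr<Op> MakeBackwardNet() {
  auto net = std::unique_ptr<Net>(new Net());
  net->AddOp(std::unique_ptr<Op>(new Op()));  // ownership moves into the container
  return net;  // implicitly moved and converted to std::unique_ptr<Op> (guaranteed since C++14)
}

The release()/static_cast at the end of BackwardRecursive in backward.cc performs the same derived-to-base handoff explicitly; the implicit unique_ptr conversion shown here is the shorthand form of it.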

paddle/framework/backward.cc (+20 -22)

@@ -15,6 +15,8 @@
 #include "paddle/framework/backward.h"
 
 #include <list>
+#include <memory>
+
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/net_op.h"
 #include "paddle/operators/recurrent_op.h"
@@ -43,11 +45,11 @@ static bool AllInSet(
   return all_in_set;
 }
 
-static std::shared_ptr<OperatorBase> NOP() {
-  auto net_op = std::make_shared<operators::NetOp>();
+static std::unique_ptr<OperatorBase> NOP() {
+  auto net_op = new operators::NetOp();
   net_op->SetType("@NOP@");
   net_op->CompleteAddOp();
-  return net_op;
+  return std::unique_ptr<OperatorBase>(net_op);
 }
 
 // Get backward operator from a forward operator, a recursive implementation.
@@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
 // operator, in a complex situation, it maybe a NetOp.
 //
 // See Backward.h for details
-static std::shared_ptr<OperatorBase> BackwardRecursive(
-    const OperatorBase& forwardOp,
-    std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);
-
-std::shared_ptr<OperatorBase> BackwardRecursive(
+static std::unique_ptr<OperatorBase> BackwardRecursive(
     const OperatorBase& forwardOp,
     std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
   // If all input gradients of forwarding operator do not need to calculate,
@@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   }
 
   // Returned gradient network
-  auto net = std::make_shared<operators::NetOp>();
+  auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());
 
   if (forwardOp.IsNetOp()) {
     // Because forwardOp is a net op, it can static_cast.
@@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // reversely travel forwardNet and collect all duplicate outputs.
     for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
          ++it, ++local_op_id) {
-      auto fwd = *it;
+      auto& fwd = *it;
       auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
-      net->AddOp(bwd);
       ForEachVarName(bwd->Outputs(),
                      [&dup_output_ops, local_op_id](const std::string& out) {
                        dup_output_ops[out].emplace_back(local_op_id);
                        return false;
                      });
+      net->AddOp(std::move(bwd));
     }
     // Get unique ID for this method.
     auto uid = uniq_id++;
@@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     // to handle this case. For each duplicate output, rename it to an alias
     // (original name with a offset), append an `add` op for its operator,
    // and finally sum all the alias variable to the final output variable y.
-    using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
+    using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
     std::list<Pos> insert_position;
     for (auto& dup_output_op : dup_output_ops) {
       const std::string& name = dup_output_op.first;
@@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
         [](const Pos& l, const Pos& r) { return l.first > r.first; });
 
     for (auto& pos : insert_position) {
-      net->InsertOp(pos.first + 1, pos.second);
+      net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
+    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
 
-    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
-                                       grad_op](const std::string& grad_input) {
+    ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
+                                          const std::string& grad_input) {
       if (no_grad_names.count(grad_input)) {
        // +1 for \0
        std::string prefix = grad_input.substr(
@@ -190,23 +188,23 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       const auto& stepnet_op =
           *static_cast<const OperatorBase*>(&rnnop.stepnet());
       // create stepnet's gradient op
-      auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
       rnn_grad_op->set_stepnet(
-          std::static_pointer_cast<operators::NetOp>(grad_stepnet));
+          BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
     }
 
     if (net->ops_.empty()) {  // Current no aux op is added to network
       return grad_op;
     }
-    net->AddOp(grad_op);
+    net->AddOp(std::move(grad_op));
   }
   net->SetType("@GENERATED_BACKWARD@");
   net->CompleteAddOp();
-  return net;
-}  // namespace framework
+  return std::unique_ptr<OperatorBase>(
+      static_cast<OperatorBase*>(net.release()));
+}
 
 // See header for comments
-std::shared_ptr<OperatorBase> Backward(
+std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_names;
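One change in the hunk above is easy to miss: the ForEachVarName lambda used to capture grad_op by value, which is cheap for a shared_ptr but ill-formed for a unique_ptr, so the capture becomes a reference. A small, self-contained illustration of the capture rules, using a hypothetical Widget type rather than Paddle code:

#include <iostream>
#include <memory>
#include <utility>

struct Widget { int value = 7; };

int main() {
  auto w = std::unique_ptr<Widget>(new Widget());

  // auto by_copy = [w] { return w->value; };  // ill-formed: std::unique_ptr is not copyable
  auto by_ref = [&w] { return w->value; };     // OK: the closure only borrows the pointer
  std::cout << by_ref() << "\n";               // prints 7; 'w' still owns the Widget

  // C++14 init-capture moves ownership into the closure instead of borrowing it.
  auto by_move = [w = std::move(w)] { return w->value; };
  std::cout << by_move() << "\n";              // prints 7; 'w' in main is now empty
  return 0;
}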

paddle/framework/backward.h (+1 -1)

@@ -20,7 +20,7 @@ namespace framework {
 
 // Create the backward operator from a forward operator.
 // TODO(yuyang18): Add more API reference comment.
-extern std::shared_ptr<OperatorBase> Backward(
+extern std::unique_ptr<OperatorBase> Backward(
     const OperatorBase& forwardOp,
     const std::unordered_set<std::string>& no_grad_vars);
 }  // namespace framework

paddle/framework/backward_test.cc (+1 -2)

@@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) {
   auto no_input_gop = f::Backward(*fwd, {"x", "b"});
   ASSERT_NE(no_input_gop, nullptr);
   ASSERT_TRUE(no_input_gop->IsNetOp());
-  ASSERT_EQ(0UL,
-            std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
+  ASSERT_EQ(0UL, static_cast<ops::NetOp *>(no_input_gop.get())->ops_.size());
 }
 
 TEST(Backward, net_fc_backward_normal) {

paddle/framework/op_registry.cc (+5 -6)

@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
                                                    const VarNameMap& inputs,
                                                    const VarNameMap& outputs,
                                                    AttributeMap attrs) {
@@ -28,10 +28,10 @@ std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
                  "Operator '%s' has not been registered.", type);
   it->second.checker_->Check(attrs);
   auto op = it->second.creator_(type, inputs, outputs, attrs);
-  return std::shared_ptr<OperatorBase>(op);
+  return std::unique_ptr<OperatorBase>(op);
 }
 
-std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
   VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
   VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
   AttributeMap attrs;
@@ -55,10 +55,9 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
   return ret_val;
 }
 
-std::shared_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
+std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
   PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
-  std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
-  return grad_op;
+  return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
 }
 
 }  // namespace framework
} // namespace framework

paddle/framework/op_registry.h (+3 -3)

@@ -77,17 +77,17 @@ class OpRegistry {
     }
   }
 
-  static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
+  static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameMap& inputs,
                                                 const VarNameMap& outputs,
                                                 AttributeMap attrs);
 
-  static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
 
   static VarNameMap ConvertOpDescVarsToVarNameMap(
       const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars);
 
-  static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
+  static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
 
   static std::unordered_map<std::string, const OpInfo>& op_info_map() {
     static std::unordered_map<std::string, const OpInfo> op_info_map_;

paddle/framework/op_registry_test.cc (+2 -4)

@@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(scale);
 
-  std::shared_ptr<paddle::framework::OperatorBase> op =
-      paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
@@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {
 
   ASSERT_TRUE(op_desc.IsInitialized());
 
-  std::shared_ptr<paddle::framework::OperatorBase> op =
-      paddle::framework::OpRegistry::CreateOp(op_desc);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);

paddle/framework/pybind.cc (+56 -85)

@@ -48,29 +48,6 @@ namespace framework {
 
 using Tensor = framework::Tensor;
 
-template <typename ClassType>
-void ExposeOperator(ClassType &m) {
-  m.def("infer_shape", &ClassType::type::InferShape)
-      .def("run", &ClassType::type::Run)
-      .def("type",
-           [](const typename ClassType::type &op) -> std::string {
-             return op.Type();
-           })
-      .def("outputs",
-           [](const typename ClassType::type &op)
-               -> std::map<std::string, std::vector<std::string>> {
-             return op.Outputs();
-           })
-      .def("inputs",
-           [](const typename ClassType::type &op) { return op.Inputs(); })
-      .def("__str__", &ClassType::type::DebugString)
-      .def("no_intermediate_outputs",
-           [](const typename ClassType::type &op) {
-             return op.OutputVars(false);
-           })
-      .def("support_gpu", &ClassType::type::SupportGPU);
-}
-
 static size_t UniqueIntegerGenerator() {
   static std::atomic<size_t> generator;
   return generator.fetch_add(1);
@@ -207,75 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
       .def(py::init<>())
       .def("__str__", string::to_string<const platform::CPUPlace &>);
 
-  py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
-      m, "Operator");
-
-  operator_base.def_static("create", [](py::bytes protobin) {
-    OpDesc desc;
-    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                   "Cannot parse user input to OpDesc");
-    PADDLE_ENFORCE(desc.IsInitialized(),
-                   "User OpDesc is not initialized, reason %s",
-                   desc.InitializationErrorString());
-    return OpRegistry::CreateOp(desc);
-  });
-
-  operator_base.def("backward",
-                    [](const OperatorBase &forwardOp,
-                       const std::unordered_set<std::string> &no_grad_vars) {
-                      return Backward(forwardOp, no_grad_vars);
-                    });
-
-  ExposeOperator(operator_base);
-
-  py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");
-
-  net.def_static("create",
-                 []() -> std::shared_ptr<operators::NetOp> {
-                   auto retv = std::make_shared<operators::NetOp>();
-                   retv->SetType("plain_net");
-                   return retv;
-                 })
-      .def("add_op", &operators::NetOp::AddOp)
-      .def("add_op",
-           [](operators::NetOp &self,
-              const std::shared_ptr<operators::NetOp> &net) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(net));
-           })
-      .def("add_op",
-           [](operators::NetOp &self,
-              const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
-             self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
+  py::class_<OperatorBase>(m, "Operator")
+      .def_static("create",
+                  [](py::bytes protobin) {
+                    OpDesc desc;
+                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                                   "Cannot parse user input to OpDesc");
+                    PADDLE_ENFORCE(desc.IsInitialized(),
+                                   "User OpDesc is not initialized, reason %s",
+                                   desc.InitializationErrorString());
+                    return OpRegistry::CreateOp(desc);
+                  })
+      .def("backward",
+           [](const OperatorBase &forwardOp,
+              const std::unordered_set<std::string> &no_grad_vars) {
+             return Backward(forwardOp, no_grad_vars).release();
           })
+      .def("infer_shape", &OperatorBase::InferShape)
+      .def("run", &OperatorBase::Run)
+      .def("type",
+           [](const OperatorBase &op) -> std::string { return op.Type(); })
+      .def("outputs",
+           [](const OperatorBase &op)
+               -> std::map<std::string, std::vector<std::string>> {
+             return op.Outputs();
+           })
+      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
+      .def("__str__", &OperatorBase::DebugString)
+      .def("no_intermediate_outputs",
+           [](const OperatorBase &op) { return op.OutputVars(false); })
+      .def("support_gpu", &OperatorBase::SupportGPU);
+
+  py::class_<operators::NetOp, OperatorBase>(m, "Net")
+      .def_static("create",
+                  []() -> operators::NetOp * {
+                    auto *retv = new operators::NetOp;
+                    retv->SetType("plain_net");
+                    return retv;
+                  })
+      .def("add_op", [](operators::NetOp &self,
+                        const OperatorBase &op) { self.AddOp(op); })
       .def("complete_add_op", &operators::NetOp::CompleteAddOp)
       .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
         self->CompleteAddOp();
       });
 
-  ExposeOperator(net);
-
   // recurrent_op
-  py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
-      rnn(m, "RecurrentOp");
-
-  rnn.def_static(
-      "create",
-      [](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
-        OpDesc desc;
-        PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
-                       "Cannot parse user input to OpDesc");
-        PADDLE_ENFORCE(desc.IsInitialized(),
-                       "User OpDesc is not initialized, reason %s",
-                       desc.InitializationErrorString());
-        auto rnn_op = OpRegistry::CreateOp(desc);
-        return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
-      })
-      .def("set_stepnet",
-           [](operators::RecurrentOp &self,
-              const std::shared_ptr<operators::NetOp> &net) -> void {
-             self.set_stepnet(net);
-           });
-  ExposeOperator(rnn);
+  py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
+      .def_static(
+          "create",
+          [](py::bytes protobin) -> operators::RecurrentOp * {
+            OpDesc desc;
+            PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
+                           "Cannot parse user input to OpDesc");
+            PADDLE_ENFORCE(desc.IsInitialized(),
+                           "User OpDesc is not initialized, reason %s",
+                           desc.InitializationErrorString());
+            auto rnn_op = OpRegistry::CreateOp(desc);
+            return static_cast<operators::RecurrentOp *>(rnn_op.release());
+          })
+      .def("set_stepnet", [](operators::RecurrentOp &self,
                              const operators::NetOp &net) -> void {
        self.set_stepnet(net.Clone());
      });
 
   m.def("unique_integer", UniqueIntegerGenerator);
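With the shared_ptr holders gone, the bindings fall back to pybind11's default std::unique_ptr holder, declare OperatorBase as the C++ base in py::class_ so a Net or RecurrentOp is accepted wherever an Operator is expected, and hand freshly created objects to Python as raw pointers, which pybind11 takes ownership of under its default return-value policy. A minimal sketch of that binding style, assuming pybind11 and hypothetical types rather than the real Paddle module:

#include <pybind11/pybind11.h>

#include <memory>
#include <string>

namespace py = pybind11;

struct Operator {
  virtual ~Operator() = default;
  virtual std::string Type() const { return "op"; }
};

struct Net : Operator {
  std::string Type() const override { return "net"; }
};

std::unique_ptr<Operator> MakeOperator() {
  return std::unique_ptr<Operator>(new Operator());
}

PYBIND11_MODULE(demo, m) {
  // Default holder is std::unique_ptr<Operator>; no shared_ptr anywhere.
  py::class_<Operator>(m, "Operator")
      // Releasing the unique_ptr hands the raw pointer to pybind11,
      // which takes ownership under the default return_value_policy.
      .def_static("create", [] { return MakeOperator().release(); })
      .def("type", &Operator::Type);

  // Declaring the base class lets Python pass a Net wherever an Operator is expected.
  py::class_<Net, Operator>(m, "Net")
      .def_static("create", [] { return new Net(); });
}

In Python, demo.Net.create() then yields an object that is also usable as an Operator, and its C++ side is destroyed when the Python object is collected.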
