Skip to content

Commit 95f4d3d

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into trt-IPluginV2Ext
2 parents df49070 + 1f28968 commit 95f4d3d

File tree

14 files changed

+549
-149
lines changed

14 files changed

+549
-149
lines changed

AUTHORS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,6 @@
7878
| zhaopu7 | Pu Zhao |
7979
| zhouxiao-coder | Xiao Zhou |
8080
| Zrachel | Rui-Qing Zhang |
81+
| jeng1220 | Bai-Cheng(Ryan) Jeng (NVIDIA) |
82+
| mingxu1067 | Ming Huang (NVIDIA) |
83+
| zlsh80826 | Reese Wang (NVIDIA) |

paddle/fluid/imperative/partial_grad_engine.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ static void GetGraphInfoBetweenTargets(
7373
std::unordered_map<OpBase *, size_t> *op_deps_ptr,
7474
std::unordered_set<VariableWrapper *> *related_grad_vars_ptr,
7575
const std::unordered_set<VariableWrapper *> &no_grad_var_grad) {
76+
VLOG(10) << "prune graph starts";
7677
/**
7778
* Step 1. Find the candidate startup grad ops, prepared for following BFS.
7879
*/
@@ -117,6 +118,8 @@ static void GetGraphInfoBetweenTargets(
117118
auto *op = op_node_pair.first;
118119
auto *node = op_node_pair.second;
119120

121+
VLOG(10) << "Visit node " << node << " , visit op " << op->Type();
122+
120123
for (auto &output_pair : op->GetOutsMap()) {
121124
if (!output_pair.second.IsGrad()) {
122125
VLOG(10) << "WARNING: " << op->Type() << " outputs a forward var";
@@ -135,6 +138,7 @@ static void GetGraphInfoBetweenTargets(
135138

136139
for (auto &pending_node : node->GradPendingNodes()) {
137140
if (visited.count(pending_node.get()) == 0) {
141+
visited.insert(pending_node.get());
138142
for (auto &pending_op : *pending_node) {
139143
preceding_ops[&pending_op].insert(op);
140144
q.emplace(&pending_op, pending_node.get());
@@ -143,6 +147,8 @@ static void GetGraphInfoBetweenTargets(
143147
}
144148
}
145149

150+
VLOG(10) << "Found endpoint op ends";
151+
146152
/**
147153
* Step 3. Based on the found input_target_grads, BFS the graph in reverse
148154
* order. `target_vars` would record all grad vars in the graph, and
@@ -246,6 +252,8 @@ static void GetGraphInfoBetweenTargets(
246252
}
247253
}
248254

255+
VLOG(10) << "Found startup op ends";
256+
249257
/**
250258
* Step 4. Prune the output_targets that are not inputs of startup_ops
251259
*/

paddle/fluid/operators/slice_op_npu.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@ namespace operators {
2525

2626
using Tensor = framework::Tensor;
2727

28-
void UpdateAttr(const framework::DDim in_dims, const std::vector<int> axes,
28+
void UpdateAttr(const framework::DDim& in_dims, const std::vector<int> axes,
2929
const std::vector<int> starts, const std::vector<int> ends,
3030
std::vector<int>* offsets, std::vector<int>* size) {
3131
int cnt = 0;
3232
for (int i = 0; i < in_dims.size(); ++i) {
3333
int start = 0;
3434
int end = in_dims[i];
35-
int axis = axes[cnt];
36-
35+
// NOTE(zhiqiu): Be careful: cnt may reach axes.size() and result in
36+
// an out-of-bounds read of axes.
37+
int axis = cnt < static_cast<int>(axes.size()) ? axes[cnt] : -1;
3738
if (axis == i) {
3839
start = starts[cnt];
3940
if (start < 0) {
@@ -63,10 +64,10 @@ class SliceNPUKernel : public framework::OpKernel<T> {
6364
auto axes = ctx.Attr<std::vector<int>>("axes");
6465
auto starts = ctx.Attr<std::vector<int>>("starts");
6566
auto ends = ctx.Attr<std::vector<int>>("ends");
67+
const auto& in_dims = input->dims();
6668

6769
out->mutable_data<T>(ctx.GetPlace());
6870

69-
auto in_dims = input->dims();
7071
std::vector<int> offsets(in_dims.size());
7172
std::vector<int> size(in_dims.size());
7273

@@ -93,8 +94,7 @@ class SliceGradNPUKernel : public framework::OpKernel<T> {
9394
auto axes = ctx.Attr<std::vector<int>>("axes");
9495
auto starts = ctx.Attr<std::vector<int>>("starts");
9596
auto ends = ctx.Attr<std::vector<int>>("ends");
96-
97-
auto in_dims = input->dims();
97+
const auto& in_dims = input->dims();
9898
int rank = in_dims.size();
9999

100100
std::vector<int> offsets(rank);

0 commit comments

Comments
 (0)