Fix MakeNonlossGradNode bug (#8508)
piiswrong authored Nov 2, 2017
1 parent 29a6a75 · commit 9e0432a
Showing 3 changed files with 11 additions and 10 deletions.
src/operator/operator_common.h (2 additions & 2 deletions)
@@ -412,8 +412,8 @@ inline std::vector<nnvm::NodeEntry> MakeZeroGradNodes(
 
 // check whether all output grads are zero.
 inline bool CheckGradAllZero(const std::vector<nnvm::NodeEntry>& ograds) {
-  const auto zero_op = nnvm::Op::Get("_zeros");
-  const auto zero_like_op = nnvm::Op::Get("zeros_like");
+  static const auto zero_op = nnvm::Op::Get("_zeros");
+  static const auto zero_like_op = nnvm::Op::Get("zeros_like");
   if (!ograds.size()) return false;
   for (const auto& grad : ograds) {
     if (!grad.node) return false;
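
Note on the hunk above: making the lookups function-local statics means nnvm::Op::Get hits the operator registry once, on the first call to CheckGradAllZero, rather than on every call. A minimal standalone sketch of that caching pattern follows; GetOp and the lookup counter are hypothetical stand-ins, not MXNet APIs.

#include <iostream>
#include <string>

namespace {
int lookup_count = 0;                   // counts simulated registry hits
const std::string kZerosOp = "_zeros";  // pretend registry entry
}  // namespace

// Hypothetical stand-in for an expensive lookup such as nnvm::Op::Get.
const std::string* GetOp(const char* /*name*/) {
  ++lookup_count;
  return &kZerosOp;
}

bool CheckUsesCachedOp() {
  // Function-local static: GetOp runs on the first call only and the
  // result is reused afterwards, which is what the commit changes
  // CheckGradAllZero to do.
  static const auto zero_op = GetOp("_zeros");
  return zero_op != nullptr;
}

int main() {
  for (int i = 0; i < 3; ++i) CheckUsesCachedOp();
  std::cout << "registry lookups: " << lookup_count << "\n";  // prints 1
}
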
src/operator/tensor/broadcast_reduce_op_index.cc (3 additions & 2 deletions)
@@ -154,8 +154,9 @@ Examples::
 .set_attr<FCompute>("FCompute<cpu>", PickOpForward<cpu>)
 .set_attr<nnvm::FGradient>("FGradient",
   [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
-    auto ret = MakeNonlossGradNode("_backward_pick", n, ograds,
-                                   {n->inputs[1]}, n->attrs.dict);
+    if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
+    auto ret = MakeGradNode("_backward_pick", n, {ograds[0], n->inputs[1]},
+                            n->attrs.dict);
     auto p = MakeNode("zeros_like", n->attrs.name + "_index_backward",
                       {n->inputs[1]}, nullptr, &n);
     ret.emplace_back(nnvm::NodeEntry{p, 0, 0});
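
What changed above: the buggy MakeNonlossGradNode call is replaced by its steps written out directly. If every incoming head gradient is a zero node, the lambda now returns zero gradients immediately (CheckGradAllZero plus MakeZeroGradNodes); otherwise it builds _backward_pick with MakeGradNode, passing the head gradient ograds[0] and the index input n->inputs[1] explicitly, and still appends a zeros_like gradient for the index, which is not differentiable. A standalone sketch of that control flow, using a simplified stand-in NodeEntry rather than the real nnvm struct:

#include <string>
#include <vector>

// Simplified stand-in for nnvm::NodeEntry: just the producing op's name.
struct NodeEntry { std::string op; };

bool CheckGradAllZero(const std::vector<NodeEntry>& ograds) {
  if (ograds.empty()) return false;
  for (const auto& g : ograds)
    if (g.op != "_zeros" && g.op != "zeros_like") return false;
  return true;
}

// Gradient builder for pick: returns one entry per input (data, index).
std::vector<NodeEntry> PickGradient(const std::vector<NodeEntry>& ograds) {
  // Short circuit: all head gradients are zero, so both input gradients
  // are zero nodes and no real backward node is created.
  if (CheckGradAllZero(ograds)) return {{"_zeros"}, {"_zeros"}};
  // Real backward node for the data input; zeros for the integer index.
  return {{"_backward_pick"}, {"zeros_like"}};
}
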
src/operator/tensor/elemwise_unary_op_basic.cc (6 additions & 6 deletions)
@@ -241,9 +241,9 @@ NNVM_REGISTER_OP(_identity_with_attr_like_rhs)
 .set_attr<nnvm::FGradient>(
     "FGradient", [](const nnvm::NodePtr& n,
                     const std::vector<nnvm::NodeEntry>& ograds) {
-      auto lhs = MakeNonlossGradNode(
-          "_backward_copy", n, ograds, {},
-          std::unordered_map<std::string, std::string>());
+      if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
+      auto lhs = MakeGradNode("_backward_copy", n, ograds,
+                              std::unordered_map<std::string, std::string>());
       auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                          {n->inputs[1]}, nullptr, &n);
       lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
@@ -284,9 +284,9 @@ NNVM_REGISTER_OP(reshape_like)
 .set_attr<nnvm::FGradient>(
     "FGradient", [](const nnvm::NodePtr& n,
                     const std::vector<nnvm::NodeEntry>& ograds) {
-      auto lhs = MakeNonlossGradNode(
-          "_backward_copy", n, ograds, {},
-          std::unordered_map<std::string, std::string>());
+      if (CheckGradAllZero(ograds)) return MakeZeroGradNodes(n, ograds);
+      auto lhs = MakeGradNode("_backward_copy", n, ograds,
+                              std::unordered_map<std::string, std::string>());
       auto ng = MakeNode("zeros_like", n->attrs.name + "_rhs_backward",
                          {n->inputs[1]}, nullptr, &n);
       lhs.push_back(nnvm::NodeEntry{ng, 0, 0});
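
_identity_with_attr_like_rhs and reshape_like receive the same fix in both hunks: the explicit zero-gradient short circuit, then MakeGradNode. Both ops forward data only from their first input, so the first input's gradient is a plain copy of the head gradient (_backward_copy), while the second input, which contributes only shape and dtype information, gets zeros_like. A small numeric sketch of that gradient pair, with plain vectors standing in for NDArrays and shapes elided:

#include <iostream>
#include <vector>

int main() {
  // Head gradient arriving at the op's output. For reshape_like the
  // output has rhs's shape but the same element count as lhs.
  std::vector<float> head_grad = {0.1f, 0.2f, 0.3f, 0.4f};

  std::vector<float> dlhs(head_grad);               // _backward_copy: identity
  std::vector<float> drhs(head_grad.size(), 0.0f);  // zeros_like: rhs is shape-only

  for (float g : dlhs) std::cout << g << ' ';  // 0.1 0.2 0.3 0.4
  std::cout << '\n';
  for (float g : drhs) std::cout << g << ' ';  // 0 0 0 0
  std::cout << '\n';
}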
