Skip to content

Commit 8b2d68f

Browse files
committed
Strict gradient boundary check (#44)
1 parent 4ec056b commit 8b2d68f

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

nnvm/src/pass/gradient.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ Graph Gradient(Graph src) {
115115
}
116116
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
117117
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
118+
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
119+
<< "Gradient function not returning enough gradient";
118120
auto git = input_grads.begin();
119121
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
120122
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));

0 commit comments

Comments
 (0)