We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 4ec056b commit 8b2d68fCopy full SHA for 8b2d68f
nnvm/src/pass/gradient.cc
@@ -115,6 +115,8 @@ Graph Gradient(Graph src) {
115
}
116
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
117
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
118
+ CHECK_EQ((*rit)->inputs.size(), input_grads.size())
119
+ << "Gradient function not returning enough gradient";
120
auto git = input_grads.begin();
121
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
122
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
0 commit comments