Skip to content

Commit

Permalink
Fix Engine::compute_dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
apaszke authored and soumith committed Feb 17, 2017
1 parent dd844f7 commit aaf41c6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
12 changes: 12 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,18 @@ def test_stochastic(self):

self.assertGreater(x.grad.data.abs().sum(), 0)

def test_stochastic_require_grad(self):
# This tests a DSD function sequence (D=deterministic, S=stochastic),
# where all functions require grad.
x = Variable(torch.randn(2, 10), requires_grad=True)
y = Variable(torch.randn(2, 10), requires_grad=True)
z = torch.normal(x + 2, 2)
o = z + y
z.reinforce(torch.randn(2, 10))
o.sum().backward()
self.assertEqual(y.grad.data, torch.ones(2, 10))
self.assertGreater(x.grad.data.abs().sum(), 0)

def test_stochastic_sequence(self):
x = Variable(torch.rand(10).clamp_(0, 1), requires_grad=True)
b = x.bernoulli()
Expand Down
9 changes: 5 additions & 4 deletions torch/csrc/autograd/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace torch { namespace autograd {
auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready) -> dependencies_type {
// First, search the graph and find all stochastic functions. Append them to the queue.
std::unordered_set<Function*> seen;
function_queue search_queue(queue.begin(), queue.end());
function_queue search_queue(queue);
while (search_queue.size() > 0) {
auto fn = search_queue.back(); search_queue.pop_back();
for (auto& prev_fn_pair : fn->previous_functions) {
Expand All @@ -33,6 +33,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready)
// to expand functions that don't require grad.
dependencies_type dependencies;
seen.clear();
// Just to make sure that they will never be added to the queue again
seen.insert(queue.begin(), queue.end());
while (queue.size() > 0) {
auto fn = std::move(queue.back()); queue.pop_back();
// This is needed only to filter out backward roots that don't require grad
Expand All @@ -42,9 +44,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready)
if (!prev_ptr) continue;
if (dynamic_cast<Variable*>(prev_ptr)) continue;
if (!prev_ptr->requires_grad) continue;
if (!prev_ptr->is_stochastic) {
dependencies[prev_ptr] += 1;
}
if (prev_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already
dependencies[prev_ptr] += 1;
if (seen.count(prev_ptr) == 0) {
seen.insert(prev_ptr);
queue.push_back(prev_ptr);
Expand Down

0 comments on commit aaf41c6

Please sign in to comment.