Skip to content

Commit aaf41c6

Browse files
apaszkesoumith
authored andcommitted
Fix Engine::compute_dependencies
1 parent dd844f7 commit aaf41c6

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

test/test_autograd.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,18 @@ def test_stochastic(self):
691691

692692
self.assertGreater(x.grad.data.abs().sum(), 0)
693693

694+
def test_stochastic_require_grad(self):
695+
# This tests a DSD function sequence (D=deterministic, S=stochastic),
696+
# where all functions require grad.
697+
x = Variable(torch.randn(2, 10), requires_grad=True)
698+
y = Variable(torch.randn(2, 10), requires_grad=True)
699+
z = torch.normal(x + 2, 2)
700+
o = z + y
701+
z.reinforce(torch.randn(2, 10))
702+
o.sum().backward()
703+
self.assertEqual(y.grad.data, torch.ones(2, 10))
704+
self.assertGreater(x.grad.data.abs().sum(), 0)
705+
694706
def test_stochastic_sequence(self):
695707
x = Variable(torch.rand(10).clamp_(0, 1), requires_grad=True)
696708
b = x.bernoulli()

torch/csrc/autograd/engine.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace torch { namespace autograd {
1111
auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready) -> dependencies_type {
1212
// First, search the graph and find all stochastic functions. Append them to the queue.
1313
std::unordered_set<Function*> seen;
14-
function_queue search_queue(queue.begin(), queue.end());
14+
function_queue search_queue(queue);
1515
while (search_queue.size() > 0) {
1616
auto fn = search_queue.back(); search_queue.pop_back();
1717
for (auto& prev_fn_pair : fn->previous_functions) {
@@ -33,6 +33,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready)
3333
// to expand functions that don't require grad.
3434
dependencies_type dependencies;
3535
seen.clear();
36+
// Just to make sure that they will never be added to the queue again
37+
seen.insert(queue.begin(), queue.end());
3638
while (queue.size() > 0) {
3739
auto fn = std::move(queue.back()); queue.pop_back();
3840
// This is needed only to filter out backward roots that don't require grad
@@ -42,9 +44,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready)
4244
if (!prev_ptr) continue;
4345
if (dynamic_cast<Variable*>(prev_ptr)) continue;
4446
if (!prev_ptr->requires_grad) continue;
45-
if (!prev_ptr->is_stochastic) {
46-
dependencies[prev_ptr] += 1;
47-
}
47+
if (prev_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already
48+
dependencies[prev_ptr] += 1;
4849
if (seen.count(prev_ptr) == 0) {
4950
seen.insert(prev_ptr);
5051
queue.push_back(prev_ptr);

0 commit comments

Comments
 (0)