From aaf41c61a657e188c6021d577de30ae041560a32 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 17 Feb 2017 04:40:22 -0800 Subject: [PATCH] Fix Engine::compute_dependencies --- test/test_autograd.py | 12 ++++++++++++ torch/csrc/autograd/engine.cpp | 9 +++++---- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/test/test_autograd.py b/test/test_autograd.py index ac05e8b96ba2a..84d19ae9a1315 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -691,6 +691,18 @@ def test_stochastic(self): self.assertGreater(x.grad.data.abs().sum(), 0) + def test_stochastic_require_grad(self): + # This tests a DSD function sequence (D=deterministic, S=stochastic), + # where all functions require grad. + x = Variable(torch.randn(2, 10), requires_grad=True) + y = Variable(torch.randn(2, 10), requires_grad=True) + z = torch.normal(x + 2, 2) + o = z + y + z.reinforce(torch.randn(2, 10)) + o.sum().backward() + self.assertEqual(y.grad.data, torch.ones(2, 10)) + self.assertGreater(x.grad.data.abs().sum(), 0) + def test_stochastic_sequence(self): x = Variable(torch.rand(10).clamp_(0, 1), requires_grad=True) b = x.bernoulli() diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 8421d2afb2def..b1ba4c36b0783 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -11,7 +11,7 @@ namespace torch { namespace autograd { auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready) -> dependencies_type { // First, search the graph and find all stochastic functions. Append them to the queue. std::unordered_set seen; - function_queue search_queue(queue.begin(), queue.end()); + function_queue search_queue(queue); while (search_queue.size() > 0) { auto fn = search_queue.back(); search_queue.pop_back(); for (auto& prev_fn_pair : fn->previous_functions) { @@ -33,6 +33,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready) // to expand functions that don't require grad. dependencies_type dependencies; seen.clear(); + // Just to make sure that they will never be added to the queue again + seen.insert(queue.begin(), queue.end()); while (queue.size() > 0) { auto fn = std::move(queue.back()); queue.pop_back(); // This is needed only to filter out backward roots that don't require grad @@ -42,9 +44,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready) if (!prev_ptr) continue; if (dynamic_cast(prev_ptr)) continue; if (!prev_ptr->requires_grad) continue; - if (!prev_ptr->is_stochastic) { - dependencies[prev_ptr] += 1; - } + if (prev_ptr->is_stochastic) continue; // Stochastic nodes were in the queue already + dependencies[prev_ptr] += 1; if (seen.count(prev_ptr) == 0) { seen.insert(prev_ptr); queue.push_back(prev_ptr);