@@ -11,7 +11,7 @@ namespace torch { namespace autograd {
11
11
auto Engine::compute_dependencies (function_queue queue, ready_queue_type& ready) -> dependencies_type {
12
12
// First, search the graph and find all stochastic functions. Append them to the queue.
13
13
std::unordered_set<Function*> seen;
14
- function_queue search_queue (queue. begin (), queue. end () );
14
+ function_queue search_queue (queue);
15
15
while (search_queue.size () > 0 ) {
16
16
auto fn = search_queue.back (); search_queue.pop_back ();
17
17
for (auto & prev_fn_pair : fn->previous_functions ) {
@@ -33,6 +33,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready)
33
33
// to expand functions that don't require grad.
34
34
dependencies_type dependencies;
35
35
seen.clear ();
36
+ // Just to make sure that they will never be added to the queue again
37
+ seen.insert (queue.begin (), queue.end ());
36
38
while (queue.size () > 0 ) {
37
39
auto fn = std::move (queue.back ()); queue.pop_back ();
38
40
// This is needed only to filter out backward roots that don't require grad
@@ -42,9 +44,8 @@ auto Engine::compute_dependencies(function_queue queue, ready_queue_type& ready)
42
44
if (!prev_ptr) continue ;
43
45
if (dynamic_cast <Variable*>(prev_ptr)) continue ;
44
46
if (!prev_ptr->requires_grad ) continue ;
45
- if (!prev_ptr->is_stochastic ) {
46
- dependencies[prev_ptr] += 1 ;
47
- }
47
+ if (prev_ptr->is_stochastic ) continue ; // Stochastic nodes were in the queue already
48
+ dependencies[prev_ptr] += 1 ;
48
49
if (seen.count (prev_ptr) == 0 ) {
49
50
seen.insert (prev_ptr);
50
51
queue.push_back (prev_ptr);
0 commit comments