Commit 3748b6d

csarofeen authored and soumith committed

Data parallel fix for pytorch#1857 (pytorch#1880)

* Data parallel fix for pytorch#1857: search recursively for a Variable in the input
* parallel_apply.py lint
1 parent b3589b0 · commit 3748b6d

File tree

1 file changed: +18 -3 lines changed

torch/nn/parallel/parallel_apply.py

Lines changed: 18 additions & 3 deletions
@@ -3,6 +3,23 @@
 from torch.autograd import Variable
 
 
+def get_a_var(obj):
+    if isinstance(obj, Variable):
+        return obj
+
+    if isinstance(obj, list) or isinstance(obj, tuple):
+        results = map(get_a_var, obj)
+        for result in results:
+            if isinstance(result, Variable):
+                return result
+    if isinstance(obj, dict):
+        results = map(get_a_var, obj.items())
+        for result in results:
+            if isinstance(result, Variable):
+                return result
+    return None
+
+
 def parallel_apply(modules, inputs, kwargs_tup=None):
     assert len(modules) == len(inputs)
     if kwargs_tup:
@@ -17,9 +34,7 @@ def parallel_apply(modules, inputs, kwargs_tup=None):
     results = {}
 
     def _worker(i, module, input, kwargs, results, lock):
-        var_input = input
-        while not isinstance(var_input, Variable):
-            var_input = var_input[0]
+        var_input = get_a_var(input)
         try:
             with torch.cuda.device_of(var_input):
                 output = module(*input, **kwargs)
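
The point of the helper: the old _worker loop stepped into input[0] repeatedly until it hit a Variable, which fails whenever the Variable is not the first element at every nesting level, and cannot handle dict inputs at all. get_a_var instead searches lists, tuples, and dicts recursively and returns the first Variable it finds, or None. Below is a usage sketch, not part of the commit itself; it assumes a PyTorch build from this commit's era (pre-0.4), where torch.autograd.Variable is still a distinct wrapper type.

# Sketch of get_a_var's behavior on nested inputs; needs no GPU.
import torch
from torch.autograd import Variable
from torch.nn.parallel.parallel_apply import get_a_var

v = Variable(torch.randn(2, 2))

# The removed loop (var_input = var_input[0]) only worked when a Variable
# sat at index 0 of every nesting level. get_a_var also finds Variables
# at later positions and inside dicts:
print(get_a_var(v) is v)                 # True: a bare Variable
print(get_a_var([1, (2, v)]) is v)       # True: nested list/tuple
print(get_a_var({'a': 0, 'b': v}) is v)  # True: dict values are searched
print(get_a_var([1, 2, 3]))              # None: no Variable anywhere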
