Skip to content

Commit 4087467

Browse files
Yuefeng Zhou and tensorflower-gardener
Yuefeng Zhou
authored and committed
Make nccl work in eager mode: wrap nccl ops in a defun; remove control dependencies on NcclAllReduce
PiperOrigin-RevId: 228801431
1 parent d2490b9 commit 4087467

File tree

3 files changed

+25
-12
lines changed

3 files changed

+25
-12
lines changed

tensorflow/python/BUILD

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5869,6 +5869,8 @@ py_library(
58695869
deps = [
58705870
":framework_for_generated_wrappers",
58715871
":nccl_ops_gen",
5872+
"//tensorflow/python/eager:context",
5873+
"//tensorflow/python/eager:def_function",
58725874
],
58735875
)
58745876

tensorflow/python/framework/auto_control_deps.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,7 @@
3535
"CollectiveReduce",
3636
"CollectiveBcastSend",
3737
"CollectiveBcastRecv",
38+
"NcclAllReduce",
3839
]
3940

4041

tensorflow/python/ops/nccl_ops.py

Lines changed: 22 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919

2020
import threading
2121

22+
from tensorflow.python.eager import context
23+
from tensorflow.python.eager import def_function
2224
from tensorflow.python.framework import device
2325
from tensorflow.python.framework import ops
2426
from tensorflow.python.ops import gen_nccl_ops
@@ -211,19 +213,27 @@ def _apply_all_reduce(reduction, tensors):
211213
raise ValueError('Must pass >0 tensors to all reduce operations')
212214

213215
shared_name = _get_shared_name()
214-
res = []
215216

216-
for t in tensors:
217-
_check_device(t)
218-
with ops.device(t.device):
219-
res.append(
220-
gen_nccl_ops.nccl_all_reduce(
221-
input=t,
222-
reduction=reduction,
223-
num_devices=len(tensors),
224-
shared_name=shared_name))
225-
226-
return res
217+
def _all_reduce():
218+
"""Call nccl allreduce."""
219+
res = []
220+
for t in tensors:
221+
_check_device(t)
222+
with ops.device(t.device):
223+
res.append(
224+
gen_nccl_ops.nccl_all_reduce(
225+
input=t,
226+
reduction=reduction,
227+
num_devices=len(tensors),
228+
shared_name=shared_name))
229+
return res
230+
231+
if context.executing_eagerly():
232+
# Nccl ops will block unless they are executed concurrently such as in a
233+
# graph or a defun.
234+
return def_function.function(_all_reduce)()
235+
else:
236+
return _all_reduce()
227237

228238

229239
def _apply_reduce(reduction, tensors):

0 commit comments

Comments (0)