File tree Expand file tree Collapse file tree 3 files changed +25
-12
lines changed Expand file tree Collapse file tree 3 files changed +25
-12
lines changed Original file line number Diff line number Diff line change @@ -5869,6 +5869,8 @@ py_library(
5869
5869
deps = [
5870
5870
":framework_for_generated_wrappers" ,
5871
5871
":nccl_ops_gen" ,
5872
+ "//tensorflow/python/eager:context" ,
5873
+ "//tensorflow/python/eager:def_function" ,
5872
5874
],
5873
5875
)
5874
5876
Original file line number Diff line number Diff line change 35
35
"CollectiveReduce" ,
36
36
"CollectiveBcastSend" ,
37
37
"CollectiveBcastRecv" ,
38
+ "NcclAllReduce" ,
38
39
]
39
40
40
41
Original file line number Diff line number Diff line change 19
19
20
20
import threading
21
21
22
+ from tensorflow .python .eager import context
23
+ from tensorflow .python .eager import def_function
22
24
from tensorflow .python .framework import device
23
25
from tensorflow .python .framework import ops
24
26
from tensorflow .python .ops import gen_nccl_ops
@@ -211,19 +213,27 @@ def _apply_all_reduce(reduction, tensors):
211
213
raise ValueError ('Must pass >0 tensors to all reduce operations' )
212
214
213
215
shared_name = _get_shared_name ()
214
- res = []
215
216
216
- for t in tensors :
217
- _check_device (t )
218
- with ops .device (t .device ):
219
- res .append (
220
- gen_nccl_ops .nccl_all_reduce (
221
- input = t ,
222
- reduction = reduction ,
223
- num_devices = len (tensors ),
224
- shared_name = shared_name ))
225
-
226
- return res
217
+ def _all_reduce ():
218
+ """Call nccl allreduce."""
219
+ res = []
220
+ for t in tensors :
221
+ _check_device (t )
222
+ with ops .device (t .device ):
223
+ res .append (
224
+ gen_nccl_ops .nccl_all_reduce (
225
+ input = t ,
226
+ reduction = reduction ,
227
+ num_devices = len (tensors ),
228
+ shared_name = shared_name ))
229
+ return res
230
+
231
+ if context .executing_eagerly ():
232
+ # Nccl ops will block unless they are executed concurrently such as in a
233
+ # graph or a defun.
234
+ return def_function .function (_all_reduce )()
235
+ else :
236
+ return _all_reduce ()
227
237
228
238
229
239
def _apply_reduce (reduction , tensors ):
You can’t perform that action at this time.
0 commit comments