@@ -124,13 +124,7 @@ def setup_module(self, module: Module) -> DistributedDataParallel:
         """Wraps the model into a :class:`~torch.nn.parallel.distributed.DistributedDataParallel` module."""
         device_ids = self._determine_ddp_device_ids()
         # https://pytorch.org/docs/stable/notes/cuda.html#id5
-        ctx = (
-            getattr(torch, f"{self.root_device.type.split(':')[0]}").stream(
-                getattr(torch, f"{self.root_device.type.split(':')[0]}").Stream()
-            )
-            if device_ids is not None
-            else nullcontext()
-        )
+        ctx = self._create_stream_context(device_ids=device_ids)
         with ctx:
             return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)
 
@@ -234,6 +228,18 @@ def _set_world_ranks(self) -> None:
     def _determine_ddp_device_ids(self) -> Optional[list[int]]:
         return None if self.root_device.type == "cpu" else [self.root_device.index]
 
+    def _create_stream_context(self, device_ids=None):
+        """Create a stream context for the current device, if supported."""
+
+        torch_lib = getattr(torch, self.root_device.type)
+        # Check if the device type supports streams and has the necessary attributes.
+        if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
+            stream = torch_lib.Stream()
+            ctx = torch_lib.stream(stream)
+        else:
+            ctx = nullcontext()
+        return ctx
+
 
 class _DDPBackwardSyncControl(_BackwardSyncControl):
     @override
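For context, here is a minimal standalone sketch of the device-agnostic stream pattern that the new `_create_stream_context` helper implements. The free function `create_stream_context`, the example device selection, and the `__main__` block are illustrative assumptions and are not part of this change; only the `getattr(torch, device.type)` lookup and the `Stream`/`stream` feature check mirror the diff above.

```python
from contextlib import nullcontext

import torch


def create_stream_context(device: torch.device, device_ids=None):
    # Resolve the per-backend namespace, e.g. torch.cuda for a "cuda" device.
    torch_lib = getattr(torch, device.type)
    # Only enter a side stream when the backend exposes the stream API and we
    # are not on CPU (device_ids is None on CPU in the strategy above).
    if hasattr(torch_lib, "Stream") and hasattr(torch_lib, "stream") and device_ids is not None:
        return torch_lib.stream(torch_lib.Stream())
    return nullcontext()


if __name__ == "__main__":
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    device_ids = None if device.type == "cpu" else [device.index]
    with create_stream_context(device, device_ids):
        # On CUDA this block runs with a side stream as the current stream;
        # on CPU the context manager is a no-op thanks to nullcontext().
        torch.ones(4, device=device) * 2
```

The feature check means backends without a stream API simply fall back to `nullcontext()`, which is what makes the helper safe to call for any `root_device` type rather than special-casing CUDA.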