File tree Expand file tree Collapse file tree 1 file changed +4
-0
lines changed
src/transformers/integrations Expand file tree Collapse file tree 1 file changed +4
-0
lines changed Original file line number Diff line number Diff line change @@ -52,6 +52,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
52 52
53 53     # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
54 54     device_type = torch._C._get_accelerator().type
   55 +  current_device = getattr(torch, device_type)
55 56     if not torch.distributed.is_initialized():
56 57         try:
57 58             rank = int(os.environ["RANK"])
@@ -73,6 +74,9 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
73 74                 "We tried to initialize torch.distributed for you, but it failed. Make "
74 75                 "sure you init torch distributed in your script to use `tp_plan='auto'`."
75 76             ) from e
77+
   78 +  if device_type != "cpu":
   79 +      current_device.set_device(int(os.environ["LOCAL_RANK"]))
76 80     index = current_device.current_device() if device_type != "cpu" else None
77 81     tp_device = torch.device(device_type, index)
7882
You can’t perform that action at this time.
0 commit comments