Commit 03a4c02

Fix tp error when torch distributed is already initialized (#38294)
fix tp error
1 parent dcaf47d commit 03a4c02

File tree

1 file changed: +4 −0 lines changed


src/transformers/integrations/tensor_parallel.py

Lines changed: 4 additions & 0 deletions
@@ -52,6 +52,7 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
 
     # Detect the accelerator on the machine. If no accelerator is available, it returns CPU.
     device_type = torch._C._get_accelerator().type
+    current_device = getattr(torch, device_type)
     if not torch.distributed.is_initialized():
         try:
             rank = int(os.environ["RANK"])
@@ -73,6 +74,9 @@ def initialize_tensor_parallelism(tp_plan, tp_size=None):
                 "We tried to initialize torch.distributed for you, but it failed. Make "
                 "sure you init torch distributed in your script to use `tp_plan='auto'`."
             ) from e
+
+    if device_type != "cpu":
+        current_device.set_device(int(os.environ["LOCAL_RANK"]))
     index = current_device.current_device() if device_type != "cpu" else None
     tp_device = torch.device(device_type, index)
 
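For context on the two added hunks: `torch._C._get_accelerator().type` returns the accelerator's type string (for example "cuda" or "cpu"), and `getattr(torch, device_type)` resolves the matching device namespace such as `torch.cuda`, which exposes `set_device` and `current_device`. A minimal standalone sketch of that resolution (not part of the commit; the `"0"` fallback for `LOCAL_RANK` is an assumption so the snippet also runs outside a torchrun launch):

```python
import os

import torch

# Resolve the accelerator the same way the patched function does.
device_type = torch._C._get_accelerator().type  # e.g. "cuda", "xpu", or "cpu"
current_device = getattr(torch, device_type)    # e.g. torch.cuda or torch.xpu

if device_type != "cpu":
    # Bind this process to its local accelerator; torchrun sets LOCAL_RANK.
    # The commit reads os.environ["LOCAL_RANK"] directly; the "0" fallback
    # here only lets the sketch run outside a distributed launch.
    current_device.set_device(int(os.environ.get("LOCAL_RANK", "0")))
    index = current_device.current_device()
else:
    index = None

tp_device = torch.device(device_type, index)
print(f"rank-local tensor-parallel device: {tp_device}")
```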

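The commit title describes the failure mode where the process group already exists before the model is loaded. A hedged repro sketch of that situation (the script name, checkpoint, and NCCL backend are illustrative assumptions, not taken from the commit): the caller initializes torch.distributed first, so `initialize_tensor_parallelism` skips its own setup branch and still has to define `current_device` and pin each rank to its local accelerator.

```python
# repro_tp.py -- illustrative only; launch with:
#   torchrun --nproc_per_node=2 repro_tp.py
import torch
import torch.distributed as dist
from transformers import AutoModelForCausalLM

# torch.distributed is initialized by the caller *before* loading the model,
# which is exactly the already-initialized case this commit fixes.
dist.init_process_group(backend="nccl")

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",  # any tensor-parallel-capable checkpoint; name is illustrative
    tp_plan="auto",             # routes through initialize_tensor_parallelism()
    torch_dtype=torch.bfloat16,
)

dist.destroy_process_group()
```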