Skip to content

Commit e720bc9

Browse files
committed
process_group: fix docs with torch==2.6.0
1 parent 866873a commit e720bc9

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

torchft/process_group.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import queue
2121
import threading
2222
from datetime import timedelta
23-
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
23+
from typing import Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union
2424

2525
import torch
2626
import torch.distributed as dist
@@ -31,16 +31,16 @@
3131
from torch.distributed import (
3232
BroadcastOptions,
3333
DeviceMesh,
34+
get_rank,
35+
init_device_mesh,
3436
PrefixStore,
3537
ProcessGroup as BaseProcessGroup,
3638
ProcessGroupGloo as BaseProcessGroupGloo,
3739
ProcessGroupNCCL as BaseProcessGroupNCCL,
3840
Store,
3941
TCPStore,
40-
get_rank,
41-
init_device_mesh,
4242
)
43-
from torch.distributed.distributed_c10d import Work, _world
43+
from torch.distributed.distributed_c10d import _world, Work
4444
from torch.futures import Future
4545

4646
if TYPE_CHECKING:
@@ -132,10 +132,20 @@ def allgather(
132132
input_tensor: List[torch.Tensor],
133133
opts: object,
134134
) -> Work:
135+
"""
136+
Gathers tensors from the whole group in a list.
137+
138+
See torch.distributed.all_gather for more details.
139+
"""
135140
raise NotImplementedError("not implemented")
136141

137142
# pyre-fixme[14]: inconsistent override
138143
def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
144+
"""
145+
Broadcasts the tensor to the whole group.
146+
147+
See torch.distributed.broadcast for more details.
148+
"""
139149
raise NotImplementedError("not implemented")
140150

141151
def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:

0 commit comments

Comments (0)