 20 |  20 |  import queue
 21 |  21 |  import threading
 22 |  22 |  from datetime import timedelta
 23 |     | -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
    |  23 | +from typing import Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union
 24 |  24 |
 25 |  25 |  import torch
 26 |  26 |  import torch.distributed as dist

 31 |  31 |  from torch.distributed import (
 32 |  32 |      BroadcastOptions,
 33 |  33 |      DeviceMesh,
    |  34 | +     get_rank,
    |  35 | +     init_device_mesh,
 34 |  36 |      PrefixStore,
 35 |  37 |      ProcessGroup as BaseProcessGroup,
 36 |  38 |      ProcessGroupGloo as BaseProcessGroupGloo,
 37 |  39 |      ProcessGroupNCCL as BaseProcessGroupNCCL,
 38 |  40 |      Store,
 39 |  41 |      TCPStore,
 40 |     | -    get_rank,
 41 |     | -    init_device_mesh,
 42 |  42 |  )
 43 |     | -from torch.distributed.distributed_c10d import Work, _world
    |  43 | +from torch.distributed.distributed_c10d import _world, Work
 44 |  44 |  from torch.futures import Future
 45 |  45 |
 46 |  46 |  if TYPE_CHECKING:

@@ -132,10 +132,20 @@ def allgather(
132 | 132 |      input_tensor: List[torch.Tensor],
133 | 133 |      opts: object,
134 | 134 |  ) -> Work:
    | 135 | +    """
    | 136 | +    Gathers tensors from the whole group in a list.
    | 137 | +
    | 138 | +    See torch.distributed.all_gather for more details.
    | 139 | +    """
135 | 140 |      raise NotImplementedError("not implemented")
136 | 141 |
137 | 142 |  # pyre-fixme[14]: inconsistent override
138 | 143 |  def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
    | 144 | +    """
    | 145 | +    Broadcasts the tensor to the whole group.
    | 146 | +
    | 147 | +    See torch.distributed.broadcast for more details.
    | 148 | +    """
139 | 149 |      raise NotImplementedError("not implemented")
140 | 150 |
141 | 151 |  def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
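For reference, the new docstrings point at the standard torch.distributed collectives. Below is a minimal usage sketch of those two calls, assuming a default process group has already been initialized; the helper name, tensor shapes, and rank-0 source are illustrative assumptions, not part of this change.

    import torch
    import torch.distributed as dist

    # Illustrative helper (not in the patch); assumes dist.init_process_group()
    # has already been called on every rank.
    def demo_collectives() -> None:
        world_size = dist.get_world_size()
        rank = dist.get_rank()

        # all_gather: each rank contributes one tensor and receives the full list.
        local = torch.full((2,), float(rank))
        gathered = [torch.empty(2) for _ in range(world_size)]
        dist.all_gather(gathered, local)

        # broadcast: the tensor on src rank 0 is copied into every other rank's buffer.
        buf = torch.arange(4.0) if rank == 0 else torch.empty(4)
        dist.broadcast(buf, src=0)

In the diff itself, allgather and broadcast still raise NotImplementedError, so a concrete backend is expected to override them with the behavior the docstrings describe.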