|
39 | 39 |
|
40 | 40 | import torch
|
41 | 41 | from torch.distributed import ReduceOp, TCPStore
|
| 42 | +from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp |
42 | 43 |
|
43 | 44 | from torchft._torchft import ManagerClient, ManagerServer
|
44 | 45 | from torchft.checkpointing import CheckpointTransport, HTTPTransport
|
45 |
| -from torchft.collectives import allreduce_quantized |
46 | 46 | from torchft.futures import future_timeout
|
47 | 47 |
|
48 | 48 | if TYPE_CHECKING:
|
49 | 49 | from torchft.process_group import ProcessGroup
|
50 | 50 |
|
| 51 | +try: |
| 52 | + # pyre-ignore[21]: Could not find a module corresponding to import `triton` |
| 53 | + import triton |
| 54 | + |
| 55 | + from torchft.collectives import allreduce_quantized |
| 56 | +except ImportError: |
| 57 | + from torch import cuda |
| 58 | + |
| 59 | + def allreduce_quantized( |
| 60 | + tensors: list[torch.Tensor], |
| 61 | + opts: AllreduceOptions | ReduceOp, |
| 62 | + process_group: "ProcessGroup", |
| 63 | + sync_stream: cuda.Stream | None = None, |
| 64 | + ) -> torch.futures.Future[None]: |
| 65 | + work = process_group.allreduce(tensors, opts) |
| 66 | + fut = work.get_future() |
| 67 | + |
| 68 | + def callback(fut: torch.futures.Future[List[torch.Tensor]]) -> None: |
| 69 | + return None |
| 70 | + |
| 71 | + return fut.then(callback) |
| 72 | + |
| 73 | + |
# Key/name constants shared between the manager server and its clients.
# NOTE(review): presumably used as TCPStore keys / env lookup by the Manager
# defined elsewhere in this file — confirm against usage sites.
MANAGER_ADDR_KEY: str = "manager_addr"  # store key holding the manager's address
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"  # env var selecting the manager port
REPLICA_ID_KEY: str = "replica_id"  # store key holding this replica's id
|
|
0 commit comments