Skip to content

Commit cb34ff1

Browse files
committed
Fix errors importing triton
1 parent b6822a6 commit cb34ff1

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

torchft/manager.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,32 @@
3939

4040
import torch
4141
from torch.distributed import ReduceOp, TCPStore
42+
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
4243

4344
from torchft._torchft import ManagerClient, ManagerServer
4445
from torchft.checkpointing import CheckpointTransport, HTTPTransport
45-
from torchft.collectives import allreduce_quantized
4646
from torchft.futures import future_timeout
4747

4848
if TYPE_CHECKING:
4949
from torchft.process_group import ProcessGroup
5050

51+
try:
52+
# pyre-ignore[21]: Could not find a module corresponding to import `triton`
53+
import triton
54+
55+
from torchft.collectives import allreduce_quantized
56+
except ImportError:
57+
from torch import cuda
58+
59+
def allreduce_quantized(
60+
tensors: list[torch.Tensor],
61+
opts: AllreduceOptions | ReduceOp,
62+
process_group: "ProcessGroup",
63+
sync_stream: cuda.Stream | None = None,
64+
) -> torch.futures.Future[List[torch.Tensor]]:
65+
return process_group.allreduce(tensors, opts)
66+
67+
5168
MANAGER_ADDR_KEY: str = "manager_addr"
5269
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
5370
REPLICA_ID_KEY: str = "replica_id"

0 commit comments

Comments
 (0)