Skip to content

Commit 600864f

Browse files
committed
Fix errors importing triton
1 parent 9f4db58 commit 600864f

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

torchft/manager.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,24 @@
3939

4040
import torch
4141
from torch.distributed import ReduceOp, TCPStore
42+
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
4243

4344
from torchft._torchft import ManagerClient, ManagerServer
4445
from torchft.checkpointing import CheckpointTransport, HTTPTransport
45-
from torchft.collectives import allreduce_quantized
4646
from torchft.futures import future_timeout
4747

4848
if TYPE_CHECKING:
4949
from torchft.process_group import ProcessGroup
5050

51+
IS_TRITON_AVAILABLE = True
52+
try:
53+
# pyre-ignore[21]: Could not find a module corresponding to import `triton`
54+
import triton
55+
56+
from torchft.collectives import allreduce_quantized
57+
except ImportError:
58+
IS_TRITON_AVAILABLE = False
59+
5160
MANAGER_ADDR_KEY: str = "manager_addr"
5261
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
5362
REPLICA_ID_KEY: str = "replica_id"
@@ -308,7 +317,7 @@ def allreduce(
308317
| torch.futures.Future[torch.Tensor]
309318
| torch.futures.Future[List[torch.Tensor]]
310319
] = None
311-
if should_quantize:
320+
if should_quantize and IS_TRITON_AVAILABLE:
312321
fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
313322
else:
314323
work = self._pg.allreduce([tensor], ReduceOp.SUM)

0 commit comments

Comments
 (0)