1 file changed: +19 -1 lines changed.
Original file line number | Diff line number | Diff line change
39
39
40
40
import torch
41
41
from torch .distributed import ReduceOp , TCPStore
42
+ from torch .distributed .distributed_c10d import AllreduceOptions , ReduceOp
42
43
43
44
from torchft ._torchft import ManagerClient , ManagerServer
44
45
from torchft .checkpointing import CheckpointTransport , HTTPTransport
45
- from torchft .collectives import allreduce_quantized
46
46
from torchft .futures import future_timeout
47
47
48
48
if TYPE_CHECKING :
49
49
from torchft .process_group import ProcessGroup
50
50
51
+ try :
52
+ # pyre-ignore[21]: Could not find a module corresponding to import `triton`
53
+ import triton
54
+
55
+ from torchft .collectives import allreduce_quantized
56
+ except ImportError :
57
+ from torch import cuda
58
+
59
+ def allreduce_quantized (
60
+ tensors : list [torch .Tensor ],
61
+ opts : AllreduceOptions | ReduceOp ,
62
+ process_group : "ProcessGroup" ,
63
+ sync_stream : cuda .Stream | None = None ,
64
+ ) -> torch .futures .Future [List [torch .Tensor ]]:
65
+ work = process_group .allreduce (tensors , opts )
66
+ return work .get_future ()
67
+
68
+
51
69
MANAGER_ADDR_KEY : str = "manager_addr"
52
70
MANAGER_PORT_ENV : str = "TORCHFT_MANAGER_PORT"
53
71
REPLICA_ID_KEY : str = "replica_id"
You can’t perform that action at this time.
0 commit comments