@@ -51,13 +51,15 @@

 import torch
 import torch.distributed as dist
+import torchcomms
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
+from torchft.torchcomms import TorchComm
 from torchft.utils import get_stream_context, synchronize
 from torchft.work import _DummyWork

@@ -163,7 +165,7 @@ class Manager:

     def __init__(
         self,
-        pg: "ProcessGroup",
+        pg: Union["ProcessGroup", TorchComm],
         load_state_dict: Optional[Callable[[T], None]],
         state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
@@ -188,6 +190,7 @@ def __init__(
     ) -> None:
         """
         Args:
+            pg: process group or torchcomms wrapper to use for communication.
             load_state_dict: function to load the state dict when recovering
             state_dict: function to save the state dict with recovering
             min_replica_size: minimum number of replicas on each step
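
With this change `Manager` accepts either a `ProcessGroup` or the new `TorchComm` wrapper. A minimal construction sketch, not from this PR: it assumes `TorchComm` wraps an already-initialized torchcomms communicator `comm`, and `my_load_state_dict`/`my_state_dict` stand in for the user's checkpoint callbacks:

```python
from torchft.manager import Manager
from torchft.torchcomms import TorchComm

# Hypothetical wiring: `comm` is an already-initialized torchcomms
# communicator; TorchComm adapts it to the interface Manager expects.
manager = Manager(
    pg=TorchComm(comm),                  # or a torch.distributed ProcessGroup
    load_state_dict=my_load_state_dict,  # user-provided callback (assumption)
    state_dict=my_state_dict,            # user-provided callback (assumption)
    min_replica_size=1,
)
```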
@@ -456,7 +459,9 @@ def allreduce(
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
+            # TODO: Support quantization with torchcomms
             if should_quantize and IS_TRITON_AVAILABLE:
+                assert isinstance(self._pg, ProcessGroup)
                 work = allreduce_quantized(
                     [tensor],
                     pg_reduce_op,
@@ -465,24 +470,36 @@ def allreduce(
                     torch.accelerator.current_stream(),
                 )
             else:
-                opts = AllreduceOptions()
-                opts.reduceOp = pg_reduce_op
-                work = self._pg.allreduce([tensor], opts)
+                # Check if we're using torchcomms or ProcessGroup
+                if isinstance(self._pg, TorchComm):
+                    # Convert PyTorch ReduceOp to torchcomms ReduceOp
+                    if pg_reduce_op == ReduceOp.SUM:
+                        tc_op = torchcomms.ReduceOp.SUM
+                    elif pg_reduce_op == ReduceOp.AVG:
+                        tc_op = torchcomms.ReduceOp.AVG
+                    else:
+                        raise AssertionError("unsupported reduce op")
+
+                    work = self._pg.allreduce(tensor, tc_op)
+                else:
+                    opts = AllreduceOptions()
+                    opts.reduceOp = pg_reduce_op
+                    work = self._pg.allreduce([tensor], opts)

             # schedule grad normalization as a continuation
             # on the Future
             @torch.profiler.record_function("torchft::manager::allreduce::callback")
             def callback(
-                fut: torch.futures.Future[torch.Tensor],
+                fut: torch.futures.Future[list[torch.Tensor]],
             ) -> torch.Tensor:
                 nonlocal tensor
                 if reduce_op == ReduceOp.AVG:
                     tensor /= num_participants
                 return tensor

-            managed_work = _ManagedWork(self, work, tensor)
+            managed_work = _ManagedWork(self, work, [tensor])
             fut = managed_work.get_future()
-            fut = cast(torch.futures.Future[torch.Tensor], fut)
+            fut = cast(torch.futures.Future[list[torch.Tensor]], fut)
             fut = fut.then(callback)
             return managed_work

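The call-site contract is unchanged: callers pass a tensor to `Manager.allreduce` and wait on the returned `_ManagedWork`, whichever backend is behind `self._pg`. A usage sketch, assuming a constructed `manager` and `model`, and that the default reduce op averages across participants as the callback above indicates:

```python
# Sketch only: all-reduce gradients through the fault-tolerant Manager.
# The same code runs whether the Manager holds a ProcessGroup or a TorchComm.
works = [
    manager.allreduce(p.grad)
    for p in model.parameters()
    if p.grad is not None
]
for work in works:
    work.wait()  # normalization happens in the chained future callback
```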
@@ -1218,7 +1235,7 @@ class _ManagedWork(dist._Work):
     def __init__(
         self,
         manager: Manager,
-        work: dist._Work,
+        work: dist._Work | torchcomms.TorchWork,
         value: object,
     ) -> None:
         super().__init__()
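
One portability note on the annotation above: `dist._Work | torchcomms.TorchWork` is PEP 604 syntax evaluated at class-definition time, so it requires Python 3.10+ unless annotations are postponed. An equivalent spelling in the `Union[...]` style the file already uses for `pg` (the `WorkLike` alias name is hypothetical):

```python
from typing import Union

import torch.distributed as dist
import torchcomms

# Same union, valid on older interpreters as well.
WorkLike = Union[dist._Work, torchcomms.TorchWork]
```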
@@ -1265,7 +1282,12 @@ def _set_future_callback(
             return

         managed_fut: _ManagedFuture[object] = self._managed_fut_head
-        managed_fut._fut = self._work.get_future()
+        if isinstance(self._work, dist._Work):
+            managed_fut._fut = self._work.get_future()
+        else:
+            fut = torch.futures.Future()
+            fut.set_result(self._value)
+            managed_fut._fut = fut
         value = self._value

         is_future_wrapped = False
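
The new `else` branch implies torchcomms work handles do not expose `get_future()` the way `dist._Work` does, so a future pre-resolved with the all-reduced tensors is substituted; the downstream `.then(...)` chain then behaves identically on both paths. A runnable illustration of that pattern:

```python
import torch

# A future resolved up front behaves like one returned by
# dist._Work.get_future(): .then() continuations still fire.
fut: torch.futures.Future = torch.futures.Future()
fut.set_result([torch.ones(2)])
fut = fut.then(lambda f: [t / 2 for t in f.value()])  # mimics AVG normalization
print(fut.wait())  # [tensor([0.5000, 0.5000])]
```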
@@ -1331,6 +1353,7 @@ def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()

         with get_stream_context(self._stream):
+            assert isinstance(self._work, dist._Work)
             self._work.block_current_stream()

         self._set_future_callback()
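
The added assert makes `block_current_stream` a ProcessGroup-only path; torchcomms work that reaches it will trip a bare `AssertionError`. If that case is expected to be user-triggerable, a louder guard may be worth it. A hedged drop-in sketch for the body above (same attributes assumed):

```python
with get_stream_context(self._stream):
    if not isinstance(self._work, dist._Work):
        # Assumption: torchcomms work handles have no block_current_stream().
        raise NotImplementedError(
            "block_current_stream is not supported for torchcomms work"
        )
    self._work.block_current_stream()
```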