@@ -51,6 +51,7 @@
 
 import torch
 import torch.distributed as dist
+import torchcomms
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
@@ -63,6 +64,7 @@
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
+    from torchft.torchcomms import TorchComm
 
 IS_TRITON_AVAILABLE = True
 try:
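The new `TorchComm` import sits under `TYPE_CHECKING`, so static type checkers see it but it is never executed at runtime, which is why the `pg` annotations below stay quoted strings. (The later `isinstance(self._pg, TorchComm)` check does need `TorchComm` bound at runtime; presumably that is handled outside the hunks shown here.) A minimal sketch of the pattern, using a hypothetical `optional_backend` module:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so the
    # dependency stays optional for anyone importing this module.
    from optional_backend import Backend  # hypothetical module

def run(backend: "Backend") -> None:
    # The quoted annotation is resolved lazily by the type checker, so no
    # runtime import of optional_backend is needed to define this function.
    backend.start()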
@@ -163,7 +165,7 @@ class Manager:
 
     def __init__(
         self,
-        pg: "ProcessGroup",
+        pg: Union["ProcessGroup", "TorchComm"],
         load_state_dict: Optional[Callable[[T], None]],
         state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
@@ -188,6 +190,7 @@ def __init__(
     ) -> None:
         """
         Args:
+            pg: process group or torchcomms wrapper to use for communication.
             load_state_dict: function to load the state dict when recovering
             state_dict: function to save the state dict with recovering
             min_replica_size: minimum number of replicas on each step
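With `pg` widened to `Union["ProcessGroup", "TorchComm"]`, callers can hand the Manager either backend without further changes. A hedged wiring sketch; the `TorchComm(...)` construction and the toy state callbacks are hypothetical stand-ins, and Manager arguments not visible in this diff are omitted:

# Hypothetical wiring; only the pg parameter's new flexibility comes from
# this diff. Everything else is a stand-in for a real trainer setup.
from torchft.manager import Manager
from torchft.torchcomms import TorchComm

state = {"step": 0}

comm = TorchComm(...)  # hypothetical constructor; see torchft.torchcomms
manager = Manager(
    pg=comm,  # a plain ProcessGroup is still accepted here, unchanged
    load_state_dict=lambda sd: state.update(sd),  # restore on recovery
    state_dict=lambda: dict(state),               # captured for recovery
    min_replica_size=2,
)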
@@ -456,7 +459,9 @@ def allreduce(
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
+            # TODO: Support quantization with torchcomms
             if should_quantize and IS_TRITON_AVAILABLE:
+                assert isinstance(self._pg, ProcessGroup)
                 work = allreduce_quantized(
                     [tensor],
                     pg_reduce_op,
@@ -465,9 +470,21 @@ def allreduce(
                     torch.accelerator.current_stream(),
                 )
             else:
-                opts = AllreduceOptions()
-                opts.reduceOp = pg_reduce_op
-                work = self._pg.allreduce([tensor], opts)
+                # Check if we're using torchcomms or ProcessGroup
+                if isinstance(self._pg, TorchComm):
+                    # Convert PyTorch ReduceOp to torchcomms ReduceOp
+                    if pg_reduce_op == ReduceOp.SUM:
+                        tc_op = torchcomms.ReduceOp.SUM
+                    elif pg_reduce_op == ReduceOp.AVG:
+                        tc_op = torchcomms.ReduceOp.AVG
+                    else:
+                        raise AssertionError("unsupported reduce op")
+
+                    work = self._pg.allreduce(tensor, tc_op)
+                else:
+                    opts = AllreduceOptions()
+                    opts.reduceOp = pg_reduce_op
+                    work = self._pg.allreduce([tensor], opts)
 
             # schedule grad normalization as a continuation
             # on the Future
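The new else-branch dispatches on the wrapper type and translates `torch.distributed.ReduceOp` into the torchcomms equivalent, supporting only SUM and AVG for now; note the torchcomms path also takes a bare tensor rather than the `List[Tensor]` plus `AllreduceOptions` shape of `ProcessGroup.allreduce`. Below is a minimal, runnable reproduction of just that translation plus a recording fake, so the mapping can be unit-tested without a real communicator; `TCReduceOp`, `FakeTorchComm`, and `to_tc_op` are hypothetical names, not torchft or torchcomms API:

# Hypothetical, minimal reproduction of the dispatch above, for testing the
# ReduceOp translation in isolation. None of these names exist in torchft or
# torchcomms; they model only the two members the diff relies on.
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Tuple

import torch
from torch.distributed import ReduceOp


class TCReduceOp(Enum):
    # Stand-in for torchcomms.ReduceOp, modeling only SUM and AVG.
    SUM = "sum"
    AVG = "avg"


@dataclass
class FakeTorchComm:
    # Records allreduce calls instead of communicating, so tests can assert
    # on the op that was passed through.
    calls: List[Tuple[torch.Tensor, TCReduceOp]] = field(default_factory=list)

    def allreduce(self, tensor: torch.Tensor, op: TCReduceOp) -> None:
        # Takes a bare tensor, matching the torchcomms call in the diff.
        self.calls.append((tensor, op))


def to_tc_op(pg_reduce_op: ReduceOp) -> TCReduceOp:
    # Same translation as the hunk: SUM and AVG pass through; anything else
    # (e.g. PREMUL_SUM) fails loudly.
    if pg_reduce_op == ReduceOp.SUM:
        return TCReduceOp.SUM
    if pg_reduce_op == ReduceOp.AVG:
        return TCReduceOp.AVG
    raise AssertionError("unsupported reduce op")


comm = FakeTorchComm()
comm.allreduce(torch.ones(4), to_tc_op(ReduceOp.SUM))
assert comm.calls[0][1] is TCReduceOp.SUM

If more ops gain torchcomms support later, the if/elif chain could become a module-level dict from `ReduceOp` to `torchcomms.ReduceOp`, which keeps the unsupported-op failure as a simple key lookup.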