
Commit 4dbd816

tushar00jain authored and facebook-github-bot committed
integrate torchcomms
Differential Revision: D86343575
1 parent 15bd67c commit 4dbd816

File tree: 3 files changed, +703, -4 lines

torchft/manager.py

Lines changed: 21 additions & 4 deletions
@@ -51,6 +51,7 @@
 
 import torch
 import torch.distributed as dist
+import torchcomms
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
@@ -63,6 +64,7 @@
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
+    from torchft.torchcomms import TorchComm
 
 IS_TRITON_AVAILABLE = True
 try:
@@ -163,7 +165,7 @@ class Manager:
 
     def __init__(
         self,
-        pg: "ProcessGroup",
+        pg: Union["ProcessGroup", "TorchComm"],
         load_state_dict: Optional[Callable[[T], None]],
         state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
@@ -188,6 +190,7 @@ def __init__(
     ) -> None:
        """
        Args:
+            pg: process group or torchcomms wrapper to use for communication.
            load_state_dict: function to load the state dict when recovering
            state_dict: function to save the state dict with recovering
            min_replica_size: minimum number of replicas on each step
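
The widened `pg` parameter means code that touches the wrapped communicator has to narrow the union before calling API that only one side exposes. A minimal sketch of that narrowing pattern, mirroring the TYPE_CHECKING-only imports in this file; `describe_comm` is a hypothetical helper, not part of this commit:

```python
# Hypothetical helper illustrating the Union["ProcessGroup", "TorchComm"]
# annotation added above; not part of torchft.
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Mirrors the TYPE_CHECKING-only imports in torchft/manager.py.
    from torchft.process_group import ProcessGroup
    from torchft.torchcomms import TorchComm


def describe_comm(pg: Union["ProcessGroup", "TorchComm"]) -> str:
    # Import at call time so the annotation can stay string-based, as in manager.py.
    from torchft.process_group import ProcessGroup

    return "process_group" if isinstance(pg, ProcessGroup) else "torchcomms"
```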
@@ -456,7 +459,9 @@ def allreduce(
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
+            # TODO: Support quantization with torchcomms
             if should_quantize and IS_TRITON_AVAILABLE:
+                assert isinstance(self._pg, ProcessGroup)
                 work = allreduce_quantized(
                     [tensor],
                     pg_reduce_op,
@@ -465,9 +470,21 @@ def allreduce(
                     torch.accelerator.current_stream(),
                 )
             else:
-                opts = AllreduceOptions()
-                opts.reduceOp = pg_reduce_op
-                work = self._pg.allreduce([tensor], opts)
+                # Check if we're using torchcomms or ProcessGroup
+                if isinstance(self._pg, TorchComm):
+                    # Convert PyTorch ReduceOp to torchcomms ReduceOp
+                    if pg_reduce_op == ReduceOp.SUM:
+                        tc_op = torchcomms.ReduceOp.SUM
+                    elif pg_reduce_op == ReduceOp.AVG:
+                        tc_op = torchcomms.ReduceOp.AVG
+                    else:
+                        raise AssertionError("unsupported reduce op")
+
+                    work = self._pg.allreduce(tensor, tc_op)
+                else:
+                    opts = AllreduceOptions()
+                    opts.reduceOp = pg_reduce_op
+                    work = self._pg.allreduce([tensor], opts)
 
             # schedule grad normalization as a continuation
             # on the Future
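
The core of this hunk is the new else-branch: a torchcomms wrapper takes the tensor directly plus a torchcomms.ReduceOp, while a ProcessGroup keeps the list-of-tensors plus AllreduceOptions signature. A standalone sketch of that dispatch, assuming only the torchcomms calls already shown in the diff; the helper name `dispatch_allreduce` is illustrative, not a torchft API:

```python
# Illustrative restatement of the dispatch in Manager.allreduce above;
# dispatch_allreduce is not part of torchft.
from typing import Union

import torch
import torchcomms
from torch.distributed import ReduceOp
from torch.distributed.distributed_c10d import AllreduceOptions

from torchft.process_group import ProcessGroup
from torchft.torchcomms import TorchComm


def dispatch_allreduce(
    pg: Union[ProcessGroup, TorchComm],
    tensor: torch.Tensor,
    pg_reduce_op: ReduceOp,
):
    if isinstance(pg, TorchComm):
        # torchcomms takes the tensor directly and its own ReduceOp enum.
        if pg_reduce_op == ReduceOp.SUM:
            tc_op = torchcomms.ReduceOp.SUM
        elif pg_reduce_op == ReduceOp.AVG:
            tc_op = torchcomms.ReduceOp.AVG
        else:
            raise AssertionError("unsupported reduce op")
        return pg.allreduce(tensor, tc_op)

    # ProcessGroup keeps the original list-of-tensors + options signature.
    opts = AllreduceOptions()
    opts.reduceOp = pg_reduce_op
    return pg.allreduce([tensor], opts)
```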

torchft/process_group.py

Lines changed: 2 additions & 0 deletions
@@ -1331,6 +1331,7 @@ class ManagedProcessGroup(ProcessGroupWrapper):
     """
 
     def __init__(self, manager: "Manager") -> None:
+        assert isinstance(manager._pg, ProcessGroup)
         super().__init__(pg=manager._pg)
 
         self._manager = manager
@@ -1350,6 +1351,7 @@ def size(self) -> int:
         return self._manager.num_participants()
 
     def getBackendName(self) -> str:
+        assert isinstance(self._manager._pg, ProcessGroup)
         return self._manager._pg.getBackendName()

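Both asserts encode the same invariant: ManagedProcessGroup delegates to ProcessGroup-specific API such as getBackendName(), so the Manager it wraps must hold a real ProcessGroup rather than a TorchComm. A hedged sketch of that guard written as a plain function; `backend_name_or_none` is not part of the codebase, and `manager` is assumed to be an already-constructed torchft Manager:

```python
# Illustrative guard mirroring the asserts added above; not part of torchft.
from typing import Optional

from torchft.manager import Manager
from torchft.process_group import ProcessGroup


def backend_name_or_none(manager: Manager) -> Optional[str]:
    # Only a ProcessGroup is known to expose getBackendName(); the TorchComm
    # wrapper is not assumed to, which is why ManagedProcessGroup asserts the
    # type before delegating.
    if isinstance(manager._pg, ProcessGroup):
        return manager._pg.getBackendName()
    return None
```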