Skip to content

Commit 7f87b9b

Browse files
tushar00jainfacebook-github-bot
authored andcommitted
integrate torchcomms (#290)
Summary: Pull Request resolved: #290 Differential Revision: D86343575
1 parent 854fb2d commit 7f87b9b

File tree

9 files changed

+724
-16
lines changed

9 files changed

+724
-16
lines changed

.github/workflows/docs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
2626
sudo apt-get install -y protobuf-compiler
2727
28-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
28+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
2929
pip install .[dev] -v
3030
3131
pip install -r docs/requirements.txt

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
2424
sudo apt-get install -y protobuf-compiler
2525
26-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
2727
pip install .[dev] -v
2828
2929
# install recent version of Rust via rustup

.github/workflows/unittest-mac.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
steps:
1313
- name: Checkout
1414
uses: actions/checkout@v4
15-
15+
1616
- name: Setup miniconda
1717
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
1818
with:
@@ -39,7 +39,7 @@ jobs:
3939
4040
python -m pip install --upgrade pip
4141
42-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
42+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cpu
4343
4444
pip install -e .[dev] -v
4545

.github/workflows/unittest.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ jobs:
4141
4242
# Optionally install torch nightly, pulls latest CUDA from pip otherwise
4343
if [ "${{ matrix.torch-version }}" = "nightly" ]; then
44-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
44+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
4545
fi
4646
if [ "${{ matrix.torch-version }}" = "test" ]; then
47-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
47+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/test/cu128
4848
fi
4949
5050
# Install dependencies

torchft/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ProcessGroupNCCL,
1717
ProcessGroupXCCL,
1818
)
19+
from torchft.torchcomms import TorchCommGloo, TorchCommNCCL
1920

2021
setup_logger("torchft_quorums")
2122
setup_logger("torchft_commits")
@@ -31,4 +32,6 @@
3132
"ProcessGroupBabyNCCL",
3233
"ProcessGroupBabyXCCL",
3334
"ProcessGroupGloo",
35+
"TorchCommNCCL",
36+
"TorchCommGloo",
3437
)

torchft/manager.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151

5252
import torch
5353
import torch.distributed as dist
54+
import torchcomms
5455
from torch.distributed import ReduceOp, TCPStore
5556
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
5657

@@ -63,6 +64,7 @@
6364

6465
if TYPE_CHECKING:
6566
from torchft.process_group import ProcessGroup
67+
from torchft.torchcomms import TorchComm
6668

6769
IS_TRITON_AVAILABLE = True
6870
try:
@@ -163,7 +165,7 @@ class Manager:
163165

164166
def __init__(
165167
self,
166-
pg: "ProcessGroup",
168+
pg: Union["ProcessGroup", "TorchComm"],
167169
load_state_dict: Optional[Callable[[T], None]],
168170
state_dict: Optional[Callable[[], T]],
169171
min_replica_size: int,
@@ -188,6 +190,7 @@ def __init__(
188190
) -> None:
189191
"""
190192
Args:
193+
pg: process group or torchcomms wrapper to use for communication.
191194
load_state_dict: function to load the state dict when recovering
192195
state_dict: function to save the state dict with recovering
193196
min_replica_size: minimum number of replicas on each step
@@ -456,7 +459,9 @@ def allreduce(
456459
try:
457460
# Run the allreduce async and save the work object so we can wait on
458461
# it later.
462+
# TODO: Support quantization with torchcomms
459463
if should_quantize and IS_TRITON_AVAILABLE:
464+
assert isinstance(self._pg, ProcessGroup)
460465
work = allreduce_quantized(
461466
[tensor],
462467
pg_reduce_op,
@@ -465,9 +470,21 @@ def allreduce(
465470
torch.accelerator.current_stream(),
466471
)
467472
else:
468-
opts = AllreduceOptions()
469-
opts.reduceOp = pg_reduce_op
470-
work = self._pg.allreduce([tensor], opts)
473+
# Check if we're using torchcomms or ProcessGroup
474+
if isinstance(self._pg, TorchComm):
475+
# Convert PyTorch ReduceOp to torchcomms ReduceOp
476+
if pg_reduce_op == ReduceOp.SUM:
477+
tc_op = torchcomms.ReduceOp.SUM
478+
elif pg_reduce_op == ReduceOp.AVG:
479+
tc_op = torchcomms.ReduceOp.AVG
480+
else:
481+
raise AssertionError("unsupported reduce op")
482+
483+
work = self._pg.allreduce(tensor, tc_op)
484+
else:
485+
opts = AllreduceOptions()
486+
opts.reduceOp = pg_reduce_op
487+
work = self._pg.allreduce([tensor], opts)
471488

472489
# schedule grad normalization as a continuation
473490
# on the Future

torchft/process_group.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,7 @@ class ManagedProcessGroup(ProcessGroupWrapper):
13311331
"""
13321332

13331333
def __init__(self, manager: "Manager") -> None:
1334+
assert isinstance(manager._pg, ProcessGroup)
13341335
super().__init__(pg=manager._pg)
13351336

13361337
self._manager = manager
@@ -1350,6 +1351,7 @@ def size(self) -> int:
13501351
return self._manager.num_participants()
13511352

13521353
def getBackendName(self) -> str:
1354+
assert isinstance(self._manager._pg, ProcessGroup)
13531355
return self._manager._pg.getBackendName()
13541356

13551357

0 commit comments

Comments
 (0)