Skip to content

Commit 9925753

Browse files
tushar00jain authored and facebook-github-bot committed
integrate torchcomms (#290)
Summary: Pull Request resolved: #290 Differential Revision: D86343575
1 parent 854fb2d commit 9925753

File tree

10 files changed

+744
-22
lines changed

10 files changed

+744
-22
lines changed

.github/workflows/docs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
2626
sudo apt-get install -y protobuf-compiler
2727
28-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
28+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
2929
pip install .[dev] -v
3030
3131
pip install -r docs/requirements.txt

.github/workflows/lint.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
2424
sudo apt-get install -y protobuf-compiler
2525
26-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
26+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
2727
pip install .[dev] -v
2828
2929
# install recent version of Rust via rustup

.github/workflows/unittest-mac.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ jobs:
1212
steps:
1313
- name: Checkout
1414
uses: actions/checkout@v4
15-
15+
1616
- name: Setup miniconda
1717
uses: pytorch/test-infra/.github/actions/setup-miniconda@main
1818
with:
@@ -39,7 +39,7 @@ jobs:
3939
4040
python -m pip install --upgrade pip
4141
42-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
42+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cpu
4343
4444
pip install -e .[dev] -v
4545

.github/workflows/unittest.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ jobs:
4141
4242
# Optionally install torch nightly, pulls latest CUDA from pip otherwise
4343
if [ "${{ matrix.torch-version }}" = "nightly" ]; then
44-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
44+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
4545
fi
4646
if [ "${{ matrix.torch-version }}" = "test" ]; then
47-
pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
47+
pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/test/cu128
4848
fi
4949
5050
# Install dependencies

torchft/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
ProcessGroupNCCL,
1717
ProcessGroupXCCL,
1818
)
19+
from torchft.torchcomms import TorchCommGloo, TorchCommNCCL
1920

2021
setup_logger("torchft_quorums")
2122
setup_logger("torchft_commits")
@@ -31,4 +32,6 @@
3132
"ProcessGroupBabyNCCL",
3233
"ProcessGroupBabyXCCL",
3334
"ProcessGroupGloo",
35+
"TorchCommNCCL",
36+
"TorchCommGloo",
3437
)

torchft/manager.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,13 +51,15 @@
5151

5252
import torch
5353
import torch.distributed as dist
54+
import torchcomms
5455
from torch.distributed import ReduceOp, TCPStore
5556
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
5657

5758
from torchft._torchft import ManagerClient, ManagerServer
5859
from torchft.checkpointing import CheckpointTransport, HTTPTransport
5960
from torchft.checkpointing._rwlock import RWLock
6061
from torchft.futures import future_timeout
62+
from torchft.torchcomms import TorchComm
6163
from torchft.utils import get_stream_context, synchronize
6264
from torchft.work import _DummyWork
6365

@@ -163,7 +165,7 @@ class Manager:
163165

164166
def __init__(
165167
self,
166-
pg: "ProcessGroup",
168+
pg: Union["ProcessGroup", TorchComm],
167169
load_state_dict: Optional[Callable[[T], None]],
168170
state_dict: Optional[Callable[[], T]],
169171
min_replica_size: int,
@@ -188,6 +190,7 @@ def __init__(
188190
) -> None:
189191
"""
190192
Args:
193+
pg: process group or torchcomms wrapper to use for communication.
191194
load_state_dict: function to load the state dict when recovering
192195
state_dict: function to save the state dict with recovering
193196
min_replica_size: minimum number of replicas on each step
@@ -456,7 +459,9 @@ def allreduce(
456459
try:
457460
# Run the allreduce async and save the work object so we can wait on
458461
# it later.
462+
# TODO: Support quantization with torchcomms
459463
if should_quantize and IS_TRITON_AVAILABLE:
464+
assert isinstance(self._pg, ProcessGroup)
460465
work = allreduce_quantized(
461466
[tensor],
462467
pg_reduce_op,
@@ -465,24 +470,36 @@ def allreduce(
465470
torch.accelerator.current_stream(),
466471
)
467472
else:
468-
opts = AllreduceOptions()
469-
opts.reduceOp = pg_reduce_op
470-
work = self._pg.allreduce([tensor], opts)
473+
# Check if we're using torchcomms or ProcessGroup
474+
if isinstance(self._pg, TorchComm):
475+
# Convert PyTorch ReduceOp to torchcomms ReduceOp
476+
if pg_reduce_op == ReduceOp.SUM:
477+
tc_op = torchcomms.ReduceOp.SUM
478+
elif pg_reduce_op == ReduceOp.AVG:
479+
tc_op = torchcomms.ReduceOp.AVG
480+
else:
481+
raise AssertionError("unsupported reduce op")
482+
483+
work = self._pg.allreduce(tensor, tc_op)
484+
else:
485+
opts = AllreduceOptions()
486+
opts.reduceOp = pg_reduce_op
487+
work = self._pg.allreduce([tensor], opts)
471488

472489
# schedule grad normalization as a continuation
473490
# on the Future
474491
@torch.profiler.record_function("torchft::manager::allreduce::callback")
475492
def callback(
476-
fut: torch.futures.Future[torch.Tensor],
493+
fut: torch.futures.Future[list[torch.Tensor]],
477494
) -> torch.Tensor:
478495
nonlocal tensor
479496
if reduce_op == ReduceOp.AVG:
480497
tensor /= num_participants
481498
return tensor
482499

483-
managed_work = _ManagedWork(self, work, tensor)
500+
managed_work = _ManagedWork(self, work, [tensor])
484501
fut = managed_work.get_future()
485-
fut = cast(torch.futures.Future[torch.Tensor], fut)
502+
fut = cast(torch.futures.Future[list[torch.Tensor]], fut)
486503
fut = fut.then(callback)
487504
return managed_work
488505

@@ -1218,7 +1235,7 @@ class _ManagedWork(dist._Work):
12181235
def __init__(
12191236
self,
12201237
manager: Manager,
1221-
work: dist._Work,
1238+
work: dist._Work | torchcomms.TorchWork,
12221239
value: object,
12231240
) -> None:
12241241
super().__init__()
@@ -1265,7 +1282,12 @@ def _set_future_callback(
12651282
return
12661283

12671284
managed_fut: _ManagedFuture[object] = self._managed_fut_head
1268-
managed_fut._fut = self._work.get_future()
1285+
if isinstance(self._work, dist._Work):
1286+
managed_fut._fut = self._work.get_future()
1287+
else:
1288+
fut = torch.futures.Future()
1289+
fut.set_result(self._value)
1290+
managed_fut._fut = fut
12691291
value = self._value
12701292

12711293
is_future_wrapped = False
@@ -1331,6 +1353,7 @@ def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
13311353
self._assert_same_stream()
13321354

13331355
with get_stream_context(self._stream):
1356+
assert isinstance(self._work, dist._Work)
13341357
self._work.block_current_stream()
13351358

13361359
self._set_future_callback()

torchft/manager_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
387387
self.assertTrue(manager._errored)
388388
# this should be skipped due to error
389389
manager.allreduce(torch.tensor([1.0])).wait()
390+
# pyre-ignore[16]: _pg is mocked
390391
self.assertEqual(manager._pg.allreduce.call_count, 2)
391392
# pyre-ignore[16]: _pg is mocked
392393
self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)
@@ -406,14 +407,17 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
406407

407408
bad_fut = torch.futures.Future()
408409
bad_fut.set_exception(RuntimeError("injected failure"))
410+
# pyre-ignore[16]: _pg is mocked
409411
manager._pg.allreduce.return_value.get_future.return_value = bad_fut
410412
manager.allreduce(torch.tensor([1.0])).wait()
413+
# pyre-ignore[16]: _pg is mocked
411414
self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
412415
self.assertTrue(manager._errored)
413416
self.assertFalse(manager.should_commit())
414417
self.assertTrue(manager._errored)
415418

416419
# cleanup
420+
# pyre-ignore[16]: _pg is mocked
417421
manager._pg.allreduce.reset_mock(return_value=True)
418422

419423
# recover on next step

torchft/process_group.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,7 @@ class ManagedProcessGroup(ProcessGroupWrapper):
13311331
"""
13321332

13331333
def __init__(self, manager: "Manager") -> None:
1334+
assert isinstance(manager._pg, ProcessGroup)
13341335
super().__init__(pg=manager._pg)
13351336

13361337
self._manager = manager
@@ -1350,6 +1351,7 @@ def size(self) -> int:
13501351
return self._manager.num_participants()
13511352

13521353
def getBackendName(self) -> str:
1354+
assert isinstance(self._manager._pg, ProcessGroup)
13531355
return self._manager._pg.getBackendName()
13541356

13551357

0 commit comments

Comments
 (0)