.github/workflows/docs.yaml (1 addition, 1 deletion)

@@ -25,7 +25,7 @@ jobs:
 
           sudo apt-get install -y protobuf-compiler
 
-          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+          pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
           pip install .[dev] -v
 
           pip install -r docs/requirements.txt
.github/workflows/lint.yaml (1 addition, 1 deletion)

@@ -23,7 +23,7 @@ jobs:
 
           sudo apt-get install -y protobuf-compiler
 
-          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+          pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
           pip install .[dev] -v
 
           # install recent version of Rust via rustup
.github/workflows/unittest-mac.yaml (2 additions, 2 deletions)

@@ -12,7 +12,7 @@ jobs:
     steps:
      - name: Checkout
        uses: actions/checkout@v4
-
+
      - name: Setup miniconda
        uses: pytorch/test-infra/.github/actions/setup-miniconda@main
        with:
@@ -39,7 +39,7 @@ jobs:
 
           python -m pip install --upgrade pip
 
-          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
+          pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cpu
 
           pip install -e .[dev] -v
 
.github/workflows/unittest.yaml (2 additions, 2 deletions)

@@ -41,10 +41,10 @@ jobs:
 
           # Optionally install torch nightly, pulls latest CUDA from pip otherwise
           if [ "${{ matrix.torch-version }}" = "nightly" ]; then
-            pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
+            pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128
           fi
           if [ "${{ matrix.torch-version }}" = "test" ]; then
-            pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
+            pip install --pre torch torchvision torchaudio torchcomms --index-url https://download.pytorch.org/whl/test/cu128
           fi
 
           # Install dependencies
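All four workflow changes above make the same edit: torchcomms is added to the existing torch/torchvision/torchaudio install from the PyTorch nightly (or test) index. A quick sanity check follows; it is not part of the PR and assumes only that the wheel installs an importable torchcomms module, the same assumption the "import torchcomms" in torchft/manager.py below relies on.

    # Post-install check that the new dependency resolved correctly.
    import importlib.util

    if importlib.util.find_spec("torchcomms") is None:
        raise SystemExit(
            "torchcomms is missing; install it from the PyTorch nightly index, e.g. "
            "pip install --pre torchcomms --index-url https://download.pytorch.org/whl/nightly/cu128"
        )
    print("torchcomms is importable")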
torchft/__init__.py (3 additions)

@@ -16,6 +16,7 @@
     ProcessGroupNCCL,
     ProcessGroupXCCL,
 )
+from torchft.torchcomms import TorchCommGloo, TorchCommNCCL
 
 setup_logger("torchft_quorums")
 setup_logger("torchft_commits")
@@ -31,4 +32,6 @@
     "ProcessGroupBabyNCCL",
     "ProcessGroupBabyXCCL",
     "ProcessGroupGloo",
+    "TorchCommNCCL",
+    "TorchCommGloo",
 )
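With these exports the torchcomms-backed wrappers are importable from the package root alongside the ProcessGroup variants. A minimal sketch, not from the PR; constructor arguments are not shown in this diff, so none are assumed here.

    # The new names sit next to the existing ProcessGroup exports in torchft.__all__.
    from torchft import TorchCommGloo, TorchCommNCCL

    print(TorchCommGloo, TorchCommNCCL)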
torchft/manager.py (32 additions, 9 deletions)

@@ -51,13 +51,15 @@
 
 import torch
 import torch.distributed as dist
+import torchcomms
 from torch.distributed import ReduceOp, TCPStore
 from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.checkpointing._rwlock import RWLock
 from torchft.futures import future_timeout
+from torchft.torchcomms import TorchComm
 from torchft.utils import get_stream_context, synchronize
 from torchft.work import _DummyWork
 
@@ -163,7 +165,7 @@ class Manager:
 
     def __init__(
         self,
-        pg: "ProcessGroup",
+        pg: Union["ProcessGroup", TorchComm],
         load_state_dict: Optional[Callable[[T], None]],
         state_dict: Optional[Callable[[], T]],
         min_replica_size: int,
@@ -188,6 +190,7 @@ def __init__(
     ) -> None:
         """
         Args:
+            pg: process group or torchcomms wrapper to use for communication.
             load_state_dict: function to load the state dict when recovering
             state_dict: function to save the state dict with recovering
             min_replica_size: minimum number of replicas on each step
@@ -456,7 +459,9 @@ def allreduce(
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
+            # TODO: Support quantization with torchcomms
             if should_quantize and IS_TRITON_AVAILABLE:
+                assert isinstance(self._pg, ProcessGroup)
                 work = allreduce_quantized(
                     [tensor],
                     pg_reduce_op,
@@ -465,24 +470,36 @@
                     torch.accelerator.current_stream(),
                 )
             else:
-                opts = AllreduceOptions()
-                opts.reduceOp = pg_reduce_op
-                work = self._pg.allreduce([tensor], opts)
+                # Check if we're using torchcomms or ProcessGroup
+                if isinstance(self._pg, TorchComm):
+                    # Convert PyTorch ReduceOp to torchcomms ReduceOp
+                    if pg_reduce_op == ReduceOp.SUM:
+                        tc_op = torchcomms.ReduceOp.SUM
+                    elif pg_reduce_op == ReduceOp.AVG:
+                        tc_op = torchcomms.ReduceOp.AVG
+                    else:
+                        raise AssertionError("unsupported reduce op")
+
+                    work = self._pg.allreduce(tensor, tc_op)
+                else:
+                    opts = AllreduceOptions()
+                    opts.reduceOp = pg_reduce_op
+                    work = self._pg.allreduce([tensor], opts)
 
             # schedule grad normalization as a continuation
             # on the Future
             @torch.profiler.record_function("torchft::manager::allreduce::callback")
             def callback(
-                fut: torch.futures.Future[torch.Tensor],
+                fut: torch.futures.Future[list[torch.Tensor]],
             ) -> torch.Tensor:
                 nonlocal tensor
                 if reduce_op == ReduceOp.AVG:
                     tensor /= num_participants
                 return tensor
 
-            managed_work = _ManagedWork(self, work, tensor)
+            managed_work = _ManagedWork(self, work, [tensor])
             fut = managed_work.get_future()
-            fut = cast(torch.futures.Future[torch.Tensor], fut)
+            fut = cast(torch.futures.Future[list[torch.Tensor]], fut)
             fut = fut.then(callback)
             return managed_work
 
@@ -1218,7 +1235,7 @@ class _ManagedWork(dist._Work):
     def __init__(
         self,
         manager: Manager,
-        work: dist._Work,
+        work: dist._Work | torchcomms.TorchWork,
         value: object,
     ) -> None:
         super().__init__()
@@ -1265,7 +1282,12 @@ def _set_future_callback(
             return
 
         managed_fut: _ManagedFuture[object] = self._managed_fut_head
-        managed_fut._fut = self._work.get_future()
+        if isinstance(self._work, dist._Work):
+            managed_fut._fut = self._work.get_future()
+        else:
+            fut = torch.futures.Future()
+            fut.set_result(self._value)
+            managed_fut._fut = fut
         value = self._value
 
         is_future_wrapped = False
@@ -1331,6 +1353,7 @@ def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
         self._assert_same_stream()
 
         with get_stream_context(self._stream):
+            assert isinstance(self._work, dist._Work)
             self._work.block_current_stream()
 
         self._set_future_callback()
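The Manager changes above come down to two small dispatch points: translating torch.distributed.ReduceOp into torchcomms.ReduceOp for the SUM/AVG cases, and wrapping a torchcomms work handle in an already-completed future because, unlike dist._Work, it is not assumed to expose get_future(). Below is a standalone sketch that mirrors the diff; torchcomms.ReduceOp and torchcomms.TorchWork are taken from the diff and not verified beyond it.

    import torch
    import torch.distributed as dist
    import torchcomms
    from torch.distributed import ReduceOp


    def to_torchcomms_reduce_op(pg_reduce_op: ReduceOp) -> "torchcomms.ReduceOp":
        # Same SUM/AVG mapping as Manager.allreduce; anything else is rejected.
        if pg_reduce_op == ReduceOp.SUM:
            return torchcomms.ReduceOp.SUM
        if pg_reduce_op == ReduceOp.AVG:
            return torchcomms.ReduceOp.AVG
        raise AssertionError("unsupported reduce op")


    def result_future(
        work: "dist._Work | torchcomms.TorchWork", value: object
    ) -> torch.futures.Future:
        # Same fallback as _ManagedWork._set_future_callback: ProcessGroup work
        # exposes a future directly; for torchcomms the already-known result
        # value is wrapped in a pre-completed future instead.
        if isinstance(work, dist._Work):
            return work.get_future()
        fut: torch.futures.Future = torch.futures.Future()
        fut.set_result(value)
        return fut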
torchft/manager_test.py (4 additions)

@@ -387,6 +387,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
         self.assertTrue(manager._errored)
         # this should be skipped due to error
         manager.allreduce(torch.tensor([1.0])).wait()
+        # pyre-ignore[16]: _pg is mocked
         self.assertEqual(manager._pg.allreduce.call_count, 2)
         # pyre-ignore[16]: _pg is mocked
         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)
@@ -406,14 +407,17 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
 
         bad_fut = torch.futures.Future()
         bad_fut.set_exception(RuntimeError("injected failure"))
+        # pyre-ignore[16]: _pg is mocked
         manager._pg.allreduce.return_value.get_future.return_value = bad_fut
         manager.allreduce(torch.tensor([1.0])).wait()
+        # pyre-ignore[16]: _pg is mocked
         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
         self.assertTrue(manager._errored)
         self.assertFalse(manager.should_commit())
         self.assertTrue(manager._errored)
 
         # cleanup
+        # pyre-ignore[16]: _pg is mocked
         manager._pg.allreduce.reset_mock(return_value=True)
 
         # recover on next step
torchft/process_group.py (2 additions)

@@ -1331,6 +1331,7 @@ class ManagedProcessGroup(ProcessGroupWrapper):
     """
 
     def __init__(self, manager: "Manager") -> None:
+        assert isinstance(manager._pg, ProcessGroup)
         super().__init__(pg=manager._pg)
 
         self._manager = manager
@@ -1350,6 +1351,7 @@ def size(self) -> int:
         return self._manager.num_participants()
 
     def getBackendName(self) -> str:
+        assert isinstance(self._manager._pg, ProcessGroup)
         return self._manager._pg.getBackendName()
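The two asserts encode the constraint that ManagedProcessGroup only wraps a plain ProcessGroup; a Manager backed by a TorchComm wrapper is not supported there. A hypothetical, friendlier guard with the same intent is sketched below; it is not part of the PR, and the TypeError wording is invented for illustration.

    from torchft.process_group import ProcessGroup
    from torchft.torchcomms import TorchComm


    def require_process_group(pg: "ProcessGroup | TorchComm") -> "ProcessGroup":
        # Raise a descriptive error instead of a bare assert when the Manager
        # is backed by torchcomms rather than a ProcessGroup.
        if not isinstance(pg, ProcessGroup):
            raise TypeError(
                "ManagedProcessGroup requires a ProcessGroup-backed Manager; "
                f"got {type(pg).__name__}"
            )
        return pg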