
Commit 93c230b

process_group/gloo: support CUDA tensors (#185)
Parent: 08ccd96

2 files changed: +28, -4 lines

torchft/process_group.py

Lines changed: 4 additions & 0 deletions
@@ -560,6 +560,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup
         pg._register_backend(
             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
         )
+        if torch.cuda.is_available():
+            pg._register_backend(
+                torch.device("cuda"), ProcessGroup.BackendType.GLOO, backend_class
+            )
         return pg
 
     def getBackendName(self) -> str:
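
Registering the Gloo backend for torch.device("cuda") as well as "cpu" means collectives issued through this process group now accept CUDA tensors (Gloo itself runs on CPU, staging device tensors through host memory). A rough sketch of what this enables, assuming the single-rank setup and the configure/allreduce signatures used in the test file below:

# Hedged sketch, not part of the commit: exercises a Gloo process group
# with a CUDA tensor, mirroring the single-rank setup in the tests below.
# Assumes torch.cuda.is_available() is True.
import torch
import torch.distributed as dist
from torch.distributed import TCPStore

from torchft.process_group import ProcessGroupGloo

store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
pg = ProcessGroupGloo()
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

# Before this commit this raised, since Gloo was only registered for CPU.
t = torch.tensor([2], device="cuda")
pg.allreduce([t], dist.ReduceOp.SUM).wait()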

torchft/process_group_test.py

Lines changed: 24 additions & 4 deletions
@@ -65,6 +65,7 @@ def dummy_init_pg() -> None:
 def _test_pg(
     pg: ProcessGroup,
     example_tensor: torch.Tensor = torch.randn((2, 3), dtype=torch.float32),
+    skip: list[str] = [],
 ) -> Dict[str, dist._Work]:
     """
     Helper function to test a set of collective operations on a given process group.
@@ -124,6 +125,8 @@ def check_tensors(arg: Any) -> None:  # pyre-ignore[2]
     works: Dict[str, dist._Work] = {}
 
     for coll_str, args in collectives:
+        if coll_str in skip:
+            continue
         try:
             coll = getattr(pg, coll_str)
             work = coll(*args)
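
The early continue means _test_pg never dispatches a collective named in skip; the CUDA variant of test_gloo_apis below uses this to sidestep two coalesced collectives that currently fail on CUDA tensors with Gloo (pytorch/pytorch#152645).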
@@ -496,7 +499,12 @@ def run_reduce_scatter_tensor_coalesced_test(
 
 
 class ProcessGroupTest(TestCase):
-    def test_gloo_apis(self) -> None:
+    @parameterized.expand(["cpu", "cuda"])
+    def test_gloo_apis(self, device: str) -> None:
+        if device == "cuda" and not torch.cuda.is_available():
+            self.skipTest("CUDA is not available")
+            return
+
         store = TCPStore(
             host_name="localhost", port=0, is_master=True, wait_for_workers=False
         )
@@ -507,11 +515,23 @@ def test_gloo_apis(self) -> None:
 
         self.assertEqual(pg.size(), 1)
 
-        _test_pg(pg)
+        _test_pg(
+            pg,
+            torch.tensor([2], device=device),
+            skip=(
+                # https://github.com/pytorch/pytorch/issues/152645
+                [
+                    "allreduce_coalesced",
+                    "allgather_into_tensor_coalesced",
+                ]
+                if device == "cuda"
+                else []
+            ),
+        )
 
-        m = nn.Linear(3, 4)
+        m = nn.Linear(3, 4).to(device)
         m = torch.nn.parallel.DistributedDataParallel(m, process_group=pg)
-        m(torch.rand(2, 3))
+        m(torch.rand(2, 3, device=device))
 
     def test_gloo_timeout(self) -> None:
         store = TCPStore(
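
For reference, parameterized.expand generates one test method per argument, so the suite reports separate cpu and cuda results rather than a single combined test. A standalone sketch of the pattern (the class and test names here are illustrative, not from the commit):

# Standalone sketch of the parameterized.expand pattern used above;
# requires the `parameterized` package.
import unittest

from parameterized import parameterized


class DeviceTest(unittest.TestCase):
    @parameterized.expand(["cpu", "cuda"])
    def test_device(self, device: str) -> None:
        # expand() creates test_device_0_cpu and test_device_1_cuda.
        self.assertIn(device, ("cpu", "cuda"))


if __name__ == "__main__":
    unittest.main()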
