
Commit 0ba1a87

ParallelProcessGroup
1 parent b84c5a6 commit 0ba1a87

File tree: 2 files changed (+121, -0 lines)


torchft/process_group.py

Lines changed: 102 additions & 0 deletions
@@ -611,6 +611,108 @@ def reduce_scatter_tensor_coalesced(
             )


+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
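
A minimal usage sketch (not part of the commit), assuming a single-rank Gloo setup; it mirrors the new test_parallel_gloo_apis test added below. The TCPStore/AllreduceOptions imports and the store address format follow the existing tests; the exact count and tensor values are illustrative only.

import torch
from torch.distributed import TCPStore
from torch.distributed.distributed_c10d import AllreduceOptions

from torchft.process_group import ParallelProcessGroup, ProcessGroupGloo

# A lightweight TCPStore on an ephemeral port, as in the tests.
store = TCPStore(
    host_name="localhost", port=0, is_master=True, wait_for_workers=False
)
store_addr = f"localhost:{store.port}/prefix"

# Wrap a base process group wrapper; each collective is sharded across
# `count` sub-groups that are created during configure().
pg = ParallelProcessGroup(base=ProcessGroupGloo(), count=4)
pg.configure(store_addr, rank=0, world_size=1)

t = torch.ones(10)
# Each sub-group allreduces one shard of the flattened tensor; with
# world_size == 1 the values come back unchanged.
pg.allreduce([t], AllreduceOptions()).wait()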

torchft/process_group_test.py

Lines changed: 19 additions & 0 deletions
@@ -40,6 +40,7 @@
 from torchft.process_group import (
     ErrorSwallowingProcessGroupWrapper,
     ManagedProcessGroup,
+    ParallelProcessGroup,
     ProcessGroup,
     ProcessGroupBabyGloo,
     ProcessGroupBabyNCCL,
@@ -690,6 +691,24 @@ def test_baby_gloo_apis(self) -> None:
         with self.assertRaisesRegex(OSError, "handle is closed"):
             a.allreduce([t], AllreduceOptions()).wait()

+    def test_parallel_gloo_apis(self) -> None:
+        store = TCPStore(
+            host_name="localhost", port=0, is_master=True, wait_for_workers=False
+        )
+
+        store_addr = f"localhost:{store.port}/prefix"
+
+        a = ParallelProcessGroup(
+            base=ProcessGroupGloo(),
+            count=4,
+        )
+        a.configure(store_addr, 0, 1)
+
+        _test_pg(
+            a,
+            skip=("reduce_scatter_tensor_coalesced"),
+        )
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipUnless(torch.cuda.is_available(), "needs CUDA")
     def test_baby_nccl_apis(self) -> None:
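
For reference, a standalone sketch (not part of the commit) of the splitting behavior that _split_tensors in process_group.py relies on: torch.tensor_split over a flattened view returns views into the original storage, so an in-place collective on each shard is reflected in the caller's tensor. The sizes below are arbitrary.

import torch

t = torch.arange(10, dtype=torch.float32)

# Split the flattened tensor into 4 near-equal shards (sizes 3, 3, 2, 2 here).
chunks = torch.tensor_split(t.view(-1), 4, dim=0)
print([c.numel() for c in chunks])  # [3, 3, 2, 2]

# The shards share storage with `t`, so an in-place update (as an in-place
# allreduce on a shard would do) shows up in the original tensor.
chunks[0].add_(100.0)
print(t[:3])  # tensor([100., 101., 102.])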

0 commit comments
