@@ -611,6 +611,108 @@ def reduce_scatter_tensor_coalesced(
         )


+class _ParallelWork(Work):
+    def __init__(self, works: List[Work]) -> None:
+        super().__init__()
+        self._works = works
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        for work in self._works:
+            if timeout is not None:
+                work.wait(timeout=timeout)
+            else:
+                work.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        futures = [work.get_future() for work in self._works]
+        return torch.futures.collect_all(futures)
+
+
+class ParallelProcessGroup(ProcessGroupWrapper):
+    def __init__(
+        self,
+        base: ProcessGroupWrapper,
+        timeout: timedelta = timedelta(seconds=60),
+        count: int = 10,
+    ) -> None:
+        super().__init__(timeout=timeout)
+
+        self._count = count
+        self._pgs = []
+
+        self._create_pg = base._create_pg
+
+    def configure(self, store_addr: str, rank: int, world_size: int) -> None:
+        # abort if already initialized
+        self.abort()
+
+        for i in range(self._count):
+            store = create_store_client(
+                f"{store_addr}/parallel{i}", timeout=self._timeout
+            )
+
+            self._pgs.append(self._create_pg(store, rank, world_size))
+
+        self._pg = self._pgs[0]
+
+    def _split_tensors(self, tensors: List[torch.Tensor]) -> List[List[torch.Tensor]]:
+        if not isinstance(tensors, (list, tuple)):
+            tensors = [tensors]
+
+        tensor_lists = [[] for _ in range(self._count)]
+        for t in tensors:
+            chunks = torch.tensor_split(t.view(-1), self._count, dim=0)
+            for i, chunk in enumerate(chunks):
+                tensor_lists[i].append(chunk)
+
+        return tensor_lists
+
+    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].allreduce(tensor_lists[i], self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def reduce(self, tensors: List[torch.Tensor], dst: int, opts: object) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(
+                    self._pgs[i].reduce(tensor_lists[i], dst, self._opts_hook(opts))
+                )
+
+            return self._wrap_work(_ParallelWork(works), opts)
+
+    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].send(tensor_lists[i], dst_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
+        tensor_lists = self._split_tensors(tensors)
+
+        with self._run_context():
+            works = []
+            for i in range(self._count):
+                works.append(self._pgs[i].recv(tensor_lists[i], src_rank, tag))
+
+            return self._wrap_work(_ParallelWork(works), None)
+
+
 class _WorkCUDATimeout(Work):
     def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
         super().__init__()
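For context, a minimal usage sketch of the new class (not part of the diff): it assumes torchft's ProcessGroupGloo as the base wrapper, a TCPStore server already reachable at the address shown, and that ParallelProcessGroup is importable from torchft.process_group; the import path, store address, and ReduceOp option are illustrative assumptions, not taken from this PR.

# Minimal usage sketch (illustrative, not from this PR): split one allreduce
# across several Gloo sub-process-groups.
from datetime import timedelta

import torch
from torch.distributed import ReduceOp
from torchft.process_group import ParallelProcessGroup, ProcessGroupGloo  # assumed import path

pg = ParallelProcessGroup(
    base=ProcessGroupGloo(),       # base wrapper whose _create_pg is reused per sub-group
    timeout=timedelta(seconds=60),
    count=4,                       # number of parallel sub-groups
)
# Assumes a TCPStore server is already running at this (illustrative) address;
# configure() creates one store client and sub-group per prefix "parallel{i}".
pg.configure(store_addr="localhost:29500/run0", rank=0, world_size=2)

t = torch.ones(1024)
# Each tensor is flattened and tensor_split() into `count` chunks; every chunk is
# reduced on its own sub-group and _ParallelWork waits on all of the results.
work = pg.allreduce([t], ReduceOp.SUM)
work.wait()

Because the chunks are views into the flattened tensor, the in-place reductions on the sub-groups update the original tensor directly once all the underlying works complete.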