 import logging
 import queue
 import threading
+from dataclasses import dataclass
 from datetime import timedelta
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    cast,
+)

 import torch
 import torch.distributed as dist
 # pyre-fixme[21]: no attribute ProcessGroupNCCL
 # pyre-fixme[21]: no attribute ProcessGroupGloo
 from torch.distributed import (
-    BroadcastOptions,
     DeviceMesh,
     PrefixStore,
     ProcessGroup as BaseProcessGroup,
     get_rank,
     init_device_mesh,
 )
-from torch.distributed.distributed_c10d import Work, _world
+from torch.distributed.distributed_c10d import (
+    AllgatherOptions,
+    AllreduceOptions,
+    BroadcastOptions,
+    ReduceOp,
+    Work,
+    _world,
+)
 from torch.futures import Future

 if TYPE_CHECKING:
 _FUTURE_EXCEPTION = "fut_exception"


+T = TypeVar("T")
+
+
 def _get(q: mp.Queue, timeout: Union[float, timedelta]) -> object:
     """
     Gets an item from a queue with a timeout. If the timeout is exceeded then
@@ -122,15 +144,17 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         raise NotImplementedError("not implemented")

     # pyre-fixme[14]: inconsistent override
-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
+    ) -> Work:
         raise NotImplementedError("not implemented")

     # pyre-fixme[14]: inconsistent override
     def allgather(
         self,
         output_tensors: List[List[torch.Tensor]],
         input_tensor: List[torch.Tensor],
-        opts: object,
+        opts: AllgatherOptions,
     ) -> Work:
         """
         Gathers tensors from the whole group in a list.
@@ -140,7 +164,9 @@ def allgather(
         raise NotImplementedError("not implemented")

     # pyre-fixme[14]: inconsistent override
-    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
+    def broadcast(
+        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
+    ) -> Work:
         """
         Broadcasts the tensor to the whole group.

@@ -567,6 +593,9 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)

+    def __del__(self) -> None:
+        self._tx.put(("del", self._op_id), timeout=self._timeout)
+

 class _BabyWorkNCCL(_BabyWork):
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
@@ -695,15 +724,18 @@ def _worker(
                 cmd = op[0]
                 if cmd == "func":
                     func_name, args, kwargs = op[1:]
+                    args = _PickleSafeOptions.unsafe_args(args)
                     fn = getattr(pg, func_name)
                     work[next_op_id] = fn(*args, **kwargs)
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
                     work[op_id].wait()
-                    del work[op_id]
                     tx.put(op_id)
+                elif cmd == "del":
+                    op_id: int = op[1]
+                    del work[op_id]
                 elif cmd == "future":
                     op_id: int = op[1]

@@ -731,6 +763,8 @@ def callback(fut: Future[object]) -> None:

                     del work[op_id]
                     tx.put((op_id, event))
+                elif cmd == "num_active_work":
+                    tx.put(len(work))
                 else:
                     raise ValueError(f"unknown cmd: {cmd}")

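Note (not part of the diff): with the new "del" and "num_active_work" commands, the worker keeps each underlying Work object until the corresponding `_BabyWork` wrapper in the parent process is garbage collected and its `__del__` enqueues the delete. A rough usage sketch, assuming a configured baby process group `pg` and a tensor `t` (names are hypothetical):

    w = pg.allreduce([t], ReduceOp.SUM)   # worker registers the Work under an op id
    w.wait()                              # "wait" no longer frees the handle
    del w                                 # __del__ sends ("del", op_id); worker drops it
    assert pg.num_active_work() == 0      # same FIFO queue, so the delete is processed first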
@@ -775,7 +809,10 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
         assert rx is not None
         assert tx is not None

-        tx.put(("func", func, args, kwargs), timeout=self._timeout)
+        tx.put(
+            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            timeout=self._timeout,
+        )

         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
@@ -784,7 +821,11 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
             pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
         )

-    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
+    def allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        opts: Union[dist.AllreduceOptions, dist.ReduceOp],
+    ) -> Work:
         assert isinstance(tensors, list), "input must be list"

         for tensor in tensors:
@@ -793,9 +834,90 @@ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:

         return self._run_func("allreduce", tensors, opts)

+    def allgather(
+        self,
+        output_tensors: List[List[torch.Tensor]],
+        input_tensor: List[torch.Tensor],
+        opts: AllgatherOptions,
+    ) -> Work:
+        assert isinstance(output_tensors, list), "input must be list"
+        assert isinstance(input_tensor, list), "input must be list"
+
+        for tensor_list in output_tensors:
+            for tensor in tensor_list:
+                if not tensor.is_shared():
+                    tensor.share_memory_()
+
+        for tensor in input_tensor:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("allgather", output_tensors, input_tensor, opts)
+
+    def broadcast(
+        self,
+        tensor_list: List[torch.Tensor],
+        opts: BroadcastOptions,
+    ) -> Work:
+        assert isinstance(tensor_list, list), "input must be list"
+
+        for tensor in tensor_list:
+            if not tensor.is_shared():
+                tensor.share_memory_()
+
+        return self._run_func("broadcast", tensor_list, opts)
+
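Note (not part of the diff): these wrappers move every tensor into shared memory before handing it to the worker subprocess, so the collective's results are visible to the caller afterwards. A hypothetical allgather call, assuming a configured baby process group `pg` (names are illustrative only):

    inp = [torch.ones(3)]
    out = [[torch.zeros(3) for _ in range(pg.size())]]
    pg.allgather(out, inp, AllgatherOptions()).wait()
    # out[0] now holds one tensor per rank; the buffers were shared, so the
    # worker process wrote the gathered values in place.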
     def size(self) -> int:
         return self._world_size

+    def num_active_work(self) -> int:
+        assert self._tx is not None
+        self._tx.put(("num_active_work",), timeout=self._timeout)
+
+        assert self._rx is not None
+        return cast(int, _get(self._rx, self._timeout))
+
+
+@dataclass
+class _PickleSafeOptions:
+    func: Callable[[], object]
+    fields: Dict[str, object]
+
+    @classmethod
+    def safe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.safe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.safe_args(arg) for arg in args]
+        elif isinstance(args, (AllreduceOptions, AllgatherOptions, BroadcastOptions)):
+            return cls.from_torch(args)
+        else:
+            return args
+
+    @classmethod
+    def unsafe_args(cls, args: T) -> T:
+        if isinstance(args, tuple):
+            return tuple(cls.unsafe_args(arg) for arg in args)
+        elif isinstance(args, list):
+            return [cls.unsafe_args(arg) for arg in args]
+        elif isinstance(args, cls):
+            return args.to_torch()
+        else:
+            return args
+
+    @classmethod
+    def from_torch(cls, opts: object) -> "_PickleSafeOptions":
+        return cls(
+            func=opts.__class__,
+            fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
+        )
+
+    def to_torch(self) -> object:
+        opts = self.func()
+        for k, v in self.fields.items():
+            setattr(opts, k, v)
+        return opts
+
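Note (not part of the diff): the c10d options objects are C++-backed and presumably do not pickle cleanly across the `mp.Queue`, so `_PickleSafeOptions` records the option class plus its public fields and rebuilds an equivalent object inside the worker, while tensors travel via shared memory instead. A minimal round-trip sketch of the conversion:

    opts = AllreduceOptions()
    opts.reduceOp = ReduceOp.SUM
    safe = _PickleSafeOptions.safe_args((opts, "allreduce"))    # picklable form
    restored, name = _PickleSafeOptions.unsafe_args(safe)       # rebuilt in the worker
    assert isinstance(restored, AllreduceOptions) and name == "allreduce"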

 class ProcessGroupBabyGloo(ProcessGroupBaby):
     """