@@ -19,17 +19,18 @@
 import logging
 import queue
 import threading
+from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Dict,
+    Generator,
     List,
     Optional,
     Tuple,
-    Type,
     TypeVar,
     Union,
     cast,
@@ -58,9 +59,9 @@
     BroadcastOptions,
     ReduceOp,
     Work,
-    _world,
 )
 from torch.futures import Future
+from torch.utils._pytree import tree_any
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -586,29 +587,52 @@ def __init__(
         self._timeout = timeout
 
     def wait(self, timeout: Optional[timedelta] = None) -> bool:
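+        # Fail fast if the worker subprocess has already died rather than
+        # blocking on the queue until the timeout expires.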
+        self._pg._assert_alive()
+
         self._tx.put(("wait", self._op_id), timeout=self._timeout)
-        assert _get(self._rx, self._timeout) == self._op_id
+        op_id, event = cast(
+            Tuple[int, Optional[torch.cuda.Event]],
+            _get(self._rx, timeout or self._timeout),
+        )
+        assert op_id == self._op_id
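+        # The worker returns a CUDA event when the op ran on a stream; waiting
+        # on it blocks the current stream (not the CPU thread) until the op
+        # has completed in the subprocess.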
+        if event is not None:
+            event.wait()
         return True
 
+    def synchronize(self) -> None:
+        # TODO: No one seems to use this, and NCCL wait already only waits the
+        # stream and is non-blocking on the CPU side, so there is no real need
+        # for a separate call.
+        raise NotImplementedError("not implemented")
+
     def get_future(self) -> Future[object]:
         return self._pg._get_future(self._op_id)
 
     def __del__(self) -> None:
         self._tx.put(("del", self._op_id), timeout=self._timeout)
 
 
-class _BabyWorkNCCL(_BabyWork):
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        self._tx.put(("synchronize", self._op_id), timeout=self._timeout)
-        # pyre-fixme[23]: unable to unpack into 2 values
-        op_id, event = _get(self._rx, self._timeout)
-        assert op_id == self._op_id
-        assert isinstance(event, torch.cuda.Event)
+def _is_any_cuda(obj: object) -> bool:
+    """
+    Returns True if any of the tensors in the object are CUDA tensors.
 
-        # Wait on Event makes the stream wait but not the CPU thread.
-        event.wait()
+    Supports lists, tuples, dicts, and tensors.
+    """
+    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)
 
-        return True
+
+@dataclass
+class _OpMetadata:
+    work: Work
+    stream: Optional[torch.cuda.Stream]
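+    # stream is None for CPU-only ops; for CUDA ops it is the worker-side
+    # stream the op was issued on, so the wait handler can rejoin it.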
+
+    @contextmanager
+    def set_stream(self) -> Generator[None, None, None]:
+        if self.stream is not None:
+            with torch.cuda.stream(self.stream):
+                yield
+        else:
+            yield
 
 
 class ProcessGroupBaby(ProcessGroup):
@@ -617,11 +641,8 @@ class ProcessGroupBaby(ProcessGroup):
     subprocess. Since it's running in a subprocess all tensors need to be in
     shared memory or will be moved to shared memory. CUDA tensors are implicitly
     shareable and don't need any changes.
-
     """
 
-    WORK_CLASS: Type[_BabyWork] = _BabyWork
-
     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         super().__init__(0, 1)
 
@@ -679,7 +700,14 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._p = ctx.Process(
             target=self._worker,
-            args=(store_addr, rank, world_size, self._tx, self._rx, self._future_queue),
+            args=(
+                store_addr,
+                rank,
+                world_size,
+                self._tx,
+                self._rx,
+                self._future_queue,
+            ),
             daemon=True,
         )
         self._p.start()
@@ -716,23 +744,70 @@ def _worker(
                 return
             tx.put(None)
 
-            work = {}
+            streams: Dict[str, torch.cuda.Stream] = {}
+            work: Dict[int, _OpMetadata] = {}
             next_op_id: int = 0
 
             while True:
                 op = rx.get()
                 cmd = op[0]
                 if cmd == "func":
-                    func_name, args, kwargs = op[1:]
-                    args = _PickleSafeOptions.unsafe_args(args)
-                    fn = getattr(pg, func_name)
-                    work[next_op_id] = fn(*args, **kwargs)
+                    func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+
+                    # To avoid potential deadlocks we need to preserve the
+                    # stream/synchronization behavior of the parent process.
+                    # We allocate one Stream per stream_id to make sure that we
+                    # don't accidentally introduce cross stream synchronization
+                    # points.
+                    if stream_id is not None:
+                        stream_key = f"{stream_device}/{stream_id}"
+                        if stream_key not in streams:
+                            streams[stream_key] = torch.cuda.Stream(
+                                device=stream_device
+                            )
+                        stream = streams[stream_key]
+                    else:
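+                        # No CUDA tensors are involved, so run the op without
+                        # a dedicated stream.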
+                        stream = None
+
+                    with (
+                        torch.cuda.stream(stream)
+                        if stream is not None
+                        else nullcontext()
+                    ):
+                        # Make the stream wait on the cuda event to make sure we
+                        # don't start the operation until the tensor is ready.
+                        if event is not None:
+                            event.wait()
+
+                        args = _PickleSafeOptions.unsafe_args(args)
+                        fn = getattr(pg, func_name)
+                        work[next_op_id] = _OpMetadata(
+                            work=fn(*args, **kwargs),
+                            stream=stream,
+                        )
                     tx.put(next_op_id)
                     next_op_id += 1
                 elif cmd == "wait":
                     op_id: int = op[1]
-                    work[op_id].wait()
-                    tx.put(op_id)
+
+                    metadata = work[op_id]
+
+                    with metadata.set_stream():
+                        # With WorkNCCL this makes the stream wait not the CPU when
+                        # no timeout is passed.
+                        metadata.work.wait()
+
+                        # Register event on the stream that we can pass to the main
+                        # process.
+                        event = (
+                            torch.cuda.current_stream().record_event(
+                                torch.cuda.Event(interprocess=True)
+                            )
+                            if metadata.stream is not None
+                            else None
+                        )
+
+                    tx.put((op_id, event))
                 elif cmd == "del":
                     op_id: int = op[1]
                     del work[op_id]
@@ -746,23 +821,8 @@ def callback(fut: Future[object]) -> None:
                         except Exception as e:
                             future_queue.put((op_id, _FUTURE_EXCEPTION, e))
 
-                    work[op_id].get_future().add_done_callback(callback)
+                    work[op_id].work.get_future().add_done_callback(callback)
                     tx.put(op_id)
-                elif cmd == "synchronize":
-                    # CUDA only, use events instead of waiting on CPU
-                    op_id = op[1]
-
-                    # With WorkNCCL this makes the stream wait not the CPU when
-                    # no timeout is passed.
-                    work[op_id].wait()
-
-                    # Register event on the stream that we can pass to the main
-                    # process.
-                    event = torch.cuda.Event(interprocess=True)
-                    event.record()
-
-                    del work[op_id]
-                    tx.put((op_id, event))
                 elif cmd == "num_active_work":
                     tx.put(len(work))
                 else:
@@ -771,6 +831,7 @@ def callback(fut: Future[object]) -> None:
         except Exception as e:
             logger.exception("worker errored")
             tx.put(e)
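+            # Re-raise so the worker process exits; the parent then surfaces
+            # the failure via _assert_alive().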
+            raise
 
     def _future_handler(self, future_queue: mp.Queue) -> None:
         try:
@@ -792,6 +853,8 @@ def _future_handler(self, future_queue: mp.Queue) -> None:
             logger.exception(f"got unexpected error in future handler: {e}")
 
     def _get_future(self, op_id: int) -> Future[object]:
+        self._assert_alive()
+
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
@@ -804,22 +867,52 @@ def _get_future(self, op_id: int) -> Future[object]:
         return fut
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
+        self._assert_alive()
+
         rx = self._rx
         tx = self._tx
         assert rx is not None
         assert tx is not None
 
+        is_cuda = _is_any_cuda(args)
+
+        stream_device = torch.cuda.current_stream().device if is_cuda else None
+        stream_id = torch.cuda.current_stream().stream_id if is_cuda else None
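+        # Record an interprocess event on the caller's current stream so the
+        # worker can order this op after any in-flight work on the inputs.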
+        event = (
+            torch.cuda.current_stream().record_event(
+                torch.cuda.Event(interprocess=True)
+            )
+            if is_cuda
+            else None
+        )
+
         tx.put(
-            ("func", func, _PickleSafeOptions.safe_args(args), kwargs),
+            (
+                "func",
+                func,
+                _PickleSafeOptions.safe_args(args),
+                kwargs,
+                stream_device,
+                stream_id,
+                event,
+            ),
             timeout=self._timeout,
         )
 
         op_id = _get(rx, self._timeout)
         assert isinstance(op_id, int), f"invalid return {op_id}"
 
-        return self.WORK_CLASS(
-            pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout
-        )
+        return _BabyWork(pg=self, tx=tx, rx=rx, op_id=op_id, timeout=self._timeout)
+
+    def _assert_alive(self) -> None:
+        """
+        Assert that the worker subprocess is still alive. This ensures that
+        operations are not performed on a dead process group and that any
+        errors are surfaced.
+        """
+        p = self._p
+        assert p is not None
+        if not p.is_alive():
+            raise RuntimeError(f"child process {p.pid=} is dead {p.exitcode=}")
 
     def allreduce(
         self,
@@ -952,8 +1045,6 @@ class ProcessGroupBabyNCCL(ProcessGroupBaby):
     tensors may leak in the current PyTorch implementation. TODO fix
     """
 
-    WORK_CLASS = _BabyWorkNCCL
-
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
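
A minimal usage sketch of the new path (not part of the diff; the store
address, device, and tensor values here are illustrative assumptions):

    import torch
    from torch.distributed import ReduceOp
    from torchft.process_group import ProcessGroupBabyNCCL

    pg = ProcessGroupBabyNCCL(timeout=60.0)
    # configure() spawns (or respawns) the worker subprocess.
    pg.configure("localhost:29500/prefix", rank=0, world_size=2)

    t = torch.ones(8, device="cuda")
    work = pg.allreduce([t], ReduceOp.SUM)
    # wait() blocks the caller's current CUDA stream on the child's
    # interprocess event rather than blocking the CPU thread.
    work.wait()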