2323from datetime import timedelta
2424from multiprocessing .connection import Connection
2525from typing import (
26- TYPE_CHECKING ,
2726 Any ,
2827 Callable ,
28+ cast ,
2929 Dict ,
3030 Generator ,
3131 List ,
3232 Optional ,
3333 Tuple ,
34+ TYPE_CHECKING ,
3435 TypeVar ,
3536 Union ,
36- cast ,
3737)
3838
3939import torch
4444# pyre-fixme[21]: no attribute ProcessGroupGloo
4545from torch .distributed import (
4646 DeviceMesh ,
47+ get_rank ,
48+ init_device_mesh ,
4749 PrefixStore ,
4850 ProcessGroup as BaseProcessGroup ,
4951 ProcessGroupGloo as BaseProcessGroupGloo ,
5052 ProcessGroupNCCL as BaseProcessGroupNCCL ,
5153 Store ,
5254 TCPStore ,
53- get_rank ,
54- init_device_mesh ,
5555)
5656from torch .distributed .distributed_c10d import (
5757 AllgatherOptions ,
@@ -970,6 +970,26 @@ def shutdown(self) -> None:
970970 self ._p .kill ()
971971
972972 def configure (self , store_addr : str , rank : int , world_size : int ) -> None :
973+ """
974+ Structure
975+ +-------------------+
976+ | |
977+ | Main Process | (updates self._futures)
978+ | | <---------------
979+ +-------------------+ |
980+ | Pipe 1 |
981+ v |
982+ +-------------------+ +-------------------+
983+ | | | |
984+ | Worker Process | -> | Future Thread |
985+ | | Pipe 2 | |
986+ +-------------------+ +-------------------+
987+
988+ Main Process: Central controller, maintains self._futures.
989+ Worker Process: Handles tasks, communicates with Future Thread.
990+ Future Thread: Manages asynchronous tasks, updates self._futures.
991+ """
992+
973993 self ._world_size = world_size
974994
975995 self .shutdown ()
@@ -990,7 +1010,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
9901010 rank ,
9911011 world_size ,
9921012 req_remote ,
993- future_remote ,
1013+ future_local ,
9941014 curr_device ,
9951015 ),
9961016 daemon = True ,
@@ -1003,7 +1023,7 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
10031023 self ._futures = {}
10041024 self ._future_thread = threading .Thread (
10051025 target = self ._future_handler ,
1006- args = (future_local ,),
1026+ args = (future_remote ,),
10071027 daemon = True ,
10081028 )
10091029 self ._future_thread .start ()
@@ -1163,11 +1183,14 @@ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
11631183
11641184 def _future_handler (self , future_pipe : _MonitoredPipe ) -> None :
11651185 try :
1166- while True :
1186+ while not self . _future_thread_shutdown_flag . is_set () :
11671187 try :
11681188 cmd = future_pipe .recv (timedelta (seconds = 10 ))
11691189 except TimeoutError :
11701190 continue
1191+ # except EOFError:
1192+ # # Pipe was closed, exit the loop
1193+ # break
11711194
11721195 op_id , mode , data , event = cast (
11731196 Tuple [int , str , object , Optional [torch .cuda .Event ]], cmd
0 commit comments