 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from multiprocessing import resource_tracker, shared_memory
+from multiprocessing import shared_memory
 from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import patch
 
 import torch
 from torch.distributed import Backend, ProcessGroup
@@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
                                                         src=ranks[0],
                                                         group=pg)
                 name = recv[0]
-                shm = shared_memory.SharedMemory(name=name)
+                # fix to https://stackoverflow.com/q/62748654/9191338
+                # Python incorrectly tracks shared memory even if it is not
+                # created by the process. The following patch is a workaround.
+                with patch("multiprocessing.resource_tracker.register",
+                           lambda *args, **kwargs: None):
+                    shm = shared_memory.SharedMemory(name=name)
                 if shm.buf[:len(magic_message)] == magic_message:
                     is_in_the_same_node[rank] = 1
     except Exception as e:
@@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0:
-            if shm:
-                shm.unlink()
-        else:
-            if shm:
-                # fix to https://stackoverflow.com/q/62748654/9191338
-                resource_tracker.unregister(
-                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+        if rank == 0 and shm:
+            shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
     return is_in_the_same_node.sum().item() == world_size
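
For context, a minimal standalone sketch of the workaround this diff applies: patching multiprocessing.resource_tracker.register to a no-op while attaching to a shared memory segment created by another process, so the resource tracker never claims ownership and cannot unlink the segment when the attaching process exits. The attach_shared_memory helper name below is illustrative only, not part of the patched module.

from multiprocessing import shared_memory
from unittest.mock import patch


def attach_shared_memory(name: str) -> shared_memory.SharedMemory:
    # Attach to an existing segment without registering it with the
    # resource tracker; otherwise Python would treat this process as the
    # owner and unlink the segment on exit even though it did not create
    # it (see https://stackoverflow.com/q/62748654/9191338).
    with patch("multiprocessing.resource_tracker.register",
               lambda *args, **kwargs: None):
        return shared_memory.SharedMemory(name=name)


if __name__ == "__main__":
    # Creator side: owns the segment and is responsible for unlink().
    owner = shared_memory.SharedMemory(create=True, size=128)
    owner.buf[:5] = b"hello"

    # Attaching side (normally a different process): close() when done,
    # but never unlink() a segment it does not own.
    reader = attach_shared_memory(owner.name)
    assert bytes(reader.buf[:5]) == b"hello"
    reader.close()

    owner.close()
    owner.unlink()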