@@ -23,8 +23,9 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from multiprocessing import resource_tracker, shared_memory
+from multiprocessing import shared_memory
 from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import patch
 
 import torch
 from torch.distributed import Backend, ProcessGroup
@@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
                                                     src=ranks[0],
                                                     group=pg)
             name = recv[0]
-            shm = shared_memory.SharedMemory(name=name)
+            # fix to https://stackoverflow.com/q/62748654/9191338
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch("multiprocessing.resource_tracker.register",
+                       lambda *args, **kwargs: None):
+                shm = shared_memory.SharedMemory(name=name)
             if shm.buf[:len(magic_message)] == magic_message:
                 is_in_the_same_node[rank] = 1
     except Exception as e:
@@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0:
-            if shm:
-                shm.unlink()
-        else:
-            if shm:
-                # fix to https://stackoverflow.com/q/62748654/9191338
-                resource_tracker.unregister(
-                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+        if rank == 0 and shm:
+            shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
     return is_in_the_same_node.sum().item() == world_size
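
For context, the change silences `multiprocessing.resource_tracker.register` while a non-rank-0 process attaches to a shared memory segment it did not create, so Python's resource tracker never claims ownership of that segment; the old `resource_tracker.unregister` hack in the cleanup path then becomes unnecessary, and only rank 0 unlinks the segment. Below is a minimal, self-contained sketch of that pattern (not vLLM code; `attach_existing_segment` is a hypothetical helper, and both sides run in one process here purely for illustration):

```python
from multiprocessing import shared_memory
from unittest.mock import patch


def attach_existing_segment(name: str) -> shared_memory.SharedMemory:
    # Hypothetical helper: attach to a segment created elsewhere without
    # letting this process's resource tracker register (and later try to
    # clean up) a segment it does not own.
    with patch("multiprocessing.resource_tracker.register",
               lambda *args, **kwargs: None):
        return shared_memory.SharedMemory(name=name)


if __name__ == "__main__":
    # Creator side: owns the segment and is the only one that unlinks it.
    owner = shared_memory.SharedMemory(create=True, size=128)
    owner.buf[:5] = b"magic"

    # Reader side (normally a different process): attach without being tracked.
    reader = attach_existing_segment(owner.name)
    assert bytes(reader.buf[:5]) == b"magic"

    reader.close()   # readers only close their own mapping
    owner.close()
    owner.unlink()   # the creator removes the segment
```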