diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index f6a2fc9b05a84..16c5297af1b53 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -23,8 +23,9 @@
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
-from multiprocessing import resource_tracker, shared_memory
+from multiprocessing import shared_memory
 from typing import Any, Dict, List, Optional, Tuple, Union
+from unittest.mock import patch
 
 import torch
 from torch.distributed import Backend, ProcessGroup
@@ -744,7 +745,12 @@ def is_in_the_same_node(pg: ProcessGroup):
                                                     src=ranks[0],
                                                     group=pg)
             name = recv[0]
-            shm = shared_memory.SharedMemory(name=name)
+            # fix to https://stackoverflow.com/q/62748654/9191338
+            # Python incorrectly tracks shared memory even if it is not
+            # created by the process. The following patch is a workaround.
+            with patch("multiprocessing.resource_tracker.register",
+                       lambda *args, **kwargs: None):
+                shm = shared_memory.SharedMemory(name=name)
             if shm.buf[:len(magic_message)] == magic_message:
                 is_in_the_same_node[rank] = 1
     except Exception as e:
@@ -757,14 +763,8 @@ def is_in_the_same_node(pg: ProcessGroup):
 
     # clean up the shared memory segment
     with contextlib.suppress(OSError):
-        if rank == 0:
-            if shm:
-                shm.unlink()
-        else:
-            if shm:
-                # fix to https://stackoverflow.com/q/62748654/9191338
-                resource_tracker.unregister(
-                    shm._name, "shared_memory")  # type: ignore[attr-defined]
+        if rank == 0 and shm:
+            shm.unlink()
     torch.distributed.all_reduce(is_in_the_same_node, group=pg)
 
     return is_in_the_same_node.sum().item() == world_size
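
The workaround in the diff can be exercised in isolation. The sketch below is not vLLM code; the helper name attach_shared_memory and the single-process demo are purely illustrative. It shows the problem the patch addresses: on POSIX, SharedMemory(name=...) registers the segment with Python's resource tracker even when the process did not create it, so the tracker may later unlink a segment it does not own and warn about "leaked" objects at exit. Suppressing resource_tracker.register only for the attach call avoids that, while the creating side stays responsible for unlink().

# Minimal sketch of the workaround (assumed helper name; not vLLM code).
from multiprocessing import shared_memory
from unittest.mock import patch


def attach_shared_memory(name: str) -> shared_memory.SharedMemory:
    # Attach to a segment created elsewhere without letting Python's
    # resource tracker register it; otherwise the tracker may unlink the
    # segment (and warn about "leaked" objects) when this process exits.
    # See https://stackoverflow.com/q/62748654/9191338.
    with patch("multiprocessing.resource_tracker.register",
               lambda *args, **kwargs: None):
        return shared_memory.SharedMemory(name=name)


if __name__ == "__main__":
    # Creator side: owns the segment and is the only one that unlinks it.
    owner = shared_memory.SharedMemory(create=True, size=128)
    owner.buf[:5] = b"hello"

    # Consumer side (normally a different process): attach by name only.
    view = attach_shared_memory(owner.name)
    assert bytes(view.buf[:5]) == b"hello"
    view.close()

    owner.close()
    owner.unlink()

Because the patch is scoped to the with block, registration behaves normally for any segment the process creates itself; only the attach path is affected, which mirrors how the diff wraps just the SharedMemory(name=name) call on non-zero ranks.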