diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index df07842edfa56..d3ac4eb78b155 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -36,6 +36,7 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
+from vllm.utils import supports_custom_op
 
 
 @dataclass
@@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None:
     _groups[group.unique_name] = weakref.ref(group)  # type: ignore
 
 
-@torch.library.custom_op("vllm::inplace_all_reduce", mutates_args=["tensor"])
-def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    group._all_reduce(tensor)
+if supports_custom_op():
 
+    @torch.library.custom_op("vllm::inplace_all_reduce",
+                             mutates_args=["tensor"])
+    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_reduce(tensor)
 
-@inplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> None:
-    return
-
-
-@torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
-def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    assert group_name in _groups, f"Group {group_name} is not found."
-    group = _groups[group_name]()
-    if group is None:
-        raise ValueError(f"Group {group_name} is destroyed.")
-    return group._all_reduce(tensor)
+    @inplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> None:
+        return
 
+    @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[])
+    def outplace_all_reduce(tensor: torch.Tensor,
+                            group_name: str) -> torch.Tensor:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        return group._all_reduce(tensor)
 
-@outplace_all_reduce.register_fake
-def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
-    return torch.empty_like(tensor)
+    @outplace_all_reduce.register_fake
+    def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
+        return torch.empty_like(tensor)
 
 
 class GroupCoordinator:
@@ -335,6 +337,9 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.world_size == 1:
             return input_
 
+        if not supports_custom_op():
+            return self._all_reduce(input_)
+
         if self.tpu_communicator is not None and \
             not self.tpu_communicator.disabled:
             # TPU handles Dynamo with its own logic.
diff --git a/vllm/utils.py b/vllm/utils.py
index 34fcfd5e44ed7..b1513b91a06c6 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -1245,6 +1245,12 @@ def supports_dynamo() -> bool:
     return base_torch_version >= Version("2.4.0")
 
 
+# Some backends use pytorch version < 2.4.0 which doesn't
+# support `torch.library.custom_op`.
+def supports_custom_op() -> bool:
+    return hasattr(torch.library, "custom_op")
+
+
 class AtomicCounter:
     """An atomic, thread-safe counter"""
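
For reference, here is a minimal standalone sketch of the pattern this diff introduces: register a torch.library custom op only when the running PyTorch build provides the API (added in torch 2.4), and fall back to calling the plain Python implementation directly otherwise. The op namespace "demo::scaled_copy" and the helper functions below are hypothetical illustrations, not part of vLLM.

# Sketch only; names in the "demo" namespace are made up for illustration.
import torch


def supports_custom_op() -> bool:
    # torch.library.custom_op only exists in torch >= 2.4.0.
    return hasattr(torch.library, "custom_op")


def _scaled_copy(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    # Plain eager implementation; called directly when custom ops are
    # unavailable, and wrapped as a custom op when they are.
    return tensor * scale


if supports_custom_op():

    @torch.library.custom_op("demo::scaled_copy", mutates_args=[])
    def scaled_copy(tensor: torch.Tensor, scale: float) -> torch.Tensor:
        return _scaled_copy(tensor, scale)

    @scaled_copy.register_fake
    def _(tensor: torch.Tensor, scale: float) -> torch.Tensor:
        # Shape/dtype-only implementation for tracing (e.g. torch.compile).
        return torch.empty_like(tensor)


def call_scaled_copy(tensor: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirrors GroupCoordinator.all_reduce in the diff: route through the
    # registered custom op when it exists, otherwise bypass it.
    if not supports_custom_op():
        return _scaled_copy(tensor, scale)
    return torch.ops.demo.scaled_copy(tensor, scale)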