[Refactor] Abstract Platform Interface for Distributed Backend and Add xccl Support for Intel XPU #19410

Open · wants to merge 8 commits into base: main

25 changes: 23 additions & 2 deletions vllm/platforms/__init__.py
@@ -7,7 +7,7 @@
from typing import TYPE_CHECKING, Optional

from vllm.plugins import load_plugins_by_group
from vllm.utils import resolve_obj_by_qualname
from vllm.utils import resolve_obj_by_qualname, supports_xccl

from .interface import _Backend # noqa: F401
from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -138,11 +138,32 @@ def xpu_platform_plugin() -> Optional[str]:
logger.debug("Checking if XPU platform is available.")
try:
# installed IPEX if the machine has XPUs.
# detect dist_backend
import os

import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
if supports_xccl():
default_backend = "xccl"
else:
logger.debug("xccl is not available.")
default_backend = "ccl"
detect_backend = os.getenv("XPU_CCL_BACKEND", default_backend)

if detect_backend not in ["xccl", "ccl"]:
raise ValueError(
f"Unknown {detect_backend} backend for XPU platform.")

if detect_backend == "ccl":
logger.debug("Checking if ccl is available.")
import oneccl_bindings_for_pytorch # noqa: F401

if hasattr(torch, 'xpu') and torch.xpu.is_available():
is_xpu = True
from vllm.platforms.xpu import XPUPlatform
XPUPlatform.dist_backend = detect_backend
logger.debug("Confirmed %s backend is available.",
XPUPlatform.dist_backend)
logger.debug("Confirmed XPU platform is available.")
except Exception as e:
logger.debug("XPU platform is not available because: %s", str(e))
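
For clarity, the backend-selection logic added above reduces to the standalone sketch below. supports_xccl() and the XPU_CCL_BACKEND environment variable come straight from this diff; pick_xpu_dist_backend() is just an illustrative wrapper name, not a function in the PR.

```python
import os

import torch
from packaging.version import Version


def supports_xccl() -> bool:
    # xccl requires PyTorch >= 2.8.0 built with the XCCL backend enabled.
    base = Version(Version(torch.__version__).base_version)
    return base >= Version("2.8.0") and torch.distributed.is_xccl_available()


def pick_xpu_dist_backend() -> str:
    # Prefer native xccl when available, fall back to the oneCCL bindings,
    # and let XPU_CCL_BACKEND force either choice explicitly.
    backend = os.getenv("XPU_CCL_BACKEND",
                        "xccl" if supports_xccl() else "ccl")
    if backend not in ("xccl", "ccl"):
        raise ValueError(f"Unknown {backend} backend for XPU platform.")
    return backend
```
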
1 change: 1 addition & 0 deletions vllm/platforms/cpu.py
@@ -37,6 +37,7 @@ class CpuPlatform(Platform):
device_name: str = "cpu"
device_type: str = "cpu"
dispatch_key: str = "CPU"
dist_backend: str = "gloo"

@property
def supported_dtypes(self) -> list[torch.dtype]:
1 change: 1 addition & 0 deletions vllm/platforms/cuda.py
@@ -56,6 +56,7 @@ class CudaPlatformBase(Platform):
device_type: str = "cuda"
dispatch_key: str = "CUDA"
ray_device_key: str = "GPU"
dist_backend: str = "nccl"
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

@property
1 change: 1 addition & 0 deletions vllm/platforms/hpu.py
@@ -26,6 +26,7 @@ class HpuPlatform(Platform):
device_type: str = "hpu"
dispatch_key: str = "HPU"
ray_device_key: str = "HPU"
dist_backend: str = "hccl"
device_control_env_var: str = "HABANA_VISIBLE_MODULES"

@classmethod
3 changes: 3 additions & 0 deletions vllm/platforms/interface.py
@@ -128,6 +128,9 @@ class Platform:
# compilation strategy.
simple_compile_backend: str = "inductor"

# The backend used for distributed communication.
dist_backend: str = ""

supported_quantization: list[str] = []

additional_env_vars: list[str] = []
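
The new dist_backend attribute is what allows the worker changes further down to drop their hard-coded "gloo"/"nccl"/"hccl"/"ccl" strings. A minimal sketch of the pattern, assuming a made-up ToyPlatform and init_worker() that exist only for illustration:

```python
import torch.distributed as dist


class Platform:
    # Empty by default; each concrete platform overrides it,
    # e.g. CUDA -> "nccl", HPU -> "hccl", CPU -> "gloo", XPU -> "ccl" or "xccl".
    dist_backend: str = ""


class ToyPlatform(Platform):
    dist_backend: str = "gloo"


def init_worker(rank: int, world_size: int, init_method: str,
                platform: Platform) -> None:
    # Workers read the backend off the platform instead of hard-coding it.
    dist.init_process_group(backend=platform.dist_backend,
                            init_method=init_method,
                            rank=rank,
                            world_size=world_size)
```
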
1 change: 1 addition & 0 deletions vllm/platforms/neuron.py
@@ -30,6 +30,7 @@ class NeuronPlatform(Platform):
device_type: str = "neuron"
ray_device_key: str = "neuron_cores"
supported_quantization: list[str] = ["neuron_quant", "fbgemm_fp8"]
dist_backend: str = "gloo"
device_control_env_var: str = "NEURON_RT_VISIBLE_CORES"

@classmethod
1 change: 1 addition & 0 deletions vllm/platforms/rocm.py
@@ -162,6 +162,7 @@ class RocmPlatform(Platform):
device_type: str = "cuda"
dispatch_key: str = "CUDA"
ray_device_key: str = "GPU"
dist_backend: str = "gloo"
# rocm shares the same device control env var as CUDA
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

1 change: 1 addition & 0 deletions vllm/platforms/tpu.py
@@ -32,6 +32,7 @@ class TpuPlatform(Platform):
device_type: str = "tpu"
dispatch_key: str = "XLA"
ray_device_key: str = "TPU"
dist_backend: str = "gloo"
device_control_env_var: str = "TPU_VISIBLE_CHIPS"
simple_compile_backend: str = "openxla"

1 change: 1 addition & 0 deletions vllm/platforms/xpu.py
@@ -26,6 +26,7 @@ class XPUPlatform(Platform):
# Intel XPU's device key is "GPU" for Ray.
# see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
ray_device_key: str = "GPU"
dist_backend: str = "ccl" # ccl | xccl
device_control_env_var: str = "ONEAPI_DEVICE_SELECTOR"

@classmethod
7 changes: 7 additions & 0 deletions vllm/utils.py
@@ -1830,6 +1830,13 @@ def supports_dynamo() -> bool:
return base_torch_version >= Version("2.4.0")


# Supports xccl with PyTorch versions >= 2.8.0 for XPU platform
def supports_xccl() -> bool:
base_torch_version = Version(Version(torch.__version__).base_version)
return base_torch_version >= Version(
"2.8.0") and torch.distributed.is_xccl_available()


# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def supports_custom_op() -> bool:
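
One detail worth noting in supports_xccl(): taking base_version first strips pre-release and local suffixes, so a development XPU build such as "2.8.0a0+git..." still passes the 2.8.0 gate even though the raw version string would compare below it. A small illustration (the version strings here are made up):

```python
from packaging.version import Version

for raw in ("2.7.1+cxx11.abi", "2.8.0a0+gitabcdef", "2.8.0"):
    # base_version drops the a0/+local parts, e.g. "2.8.0a0+gitabcdef" -> "2.8.0".
    base = Version(Version(raw).base_version)
    print(f"{raw}: base={base}, passes gate: {base >= Version('2.8.0')}")
```
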
4 changes: 3 additions & 1 deletion vllm/v1/worker/cpu_worker.py
@@ -10,6 +10,7 @@
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.model_executor.utils import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
@@ -57,7 +58,8 @@ def init_device(self):
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank, "gloo")
self.local_rank,
current_platform.dist_backend)
# Set random seed.
set_random_seed(self.model_config.seed)

3 changes: 2 additions & 1 deletion vllm/v1/worker/gpu_worker.py
@@ -152,7 +152,8 @@ def init_device(self):
# Initialize the distributed environment.
init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
self.local_rank,
current_platform.dist_backend)
# Set random seed.
set_random_seed(self.model_config.seed)

3 changes: 2 additions & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -18,6 +18,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
@@ -284,7 +285,7 @@ def _init_tpu_worker_distributed_environment(
rank=rank,
local_rank=local_rank,
distributed_init_method=distributed_init_method,
backend="gloo",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
3 changes: 2 additions & 1 deletion vllm/worker/cpu_worker.py
@@ -17,6 +17,7 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
@@ -386,7 +387,7 @@ def init_distributed_environment(self) -> None:
world_size=parallel_config.world_size,
rank=rank,
distributed_init_method=distributed_init_method,
backend="gloo",
backend=current_platform.dist_backend,
)

# A small all_reduce for warmup.
3 changes: 2 additions & 1 deletion vllm/worker/hpu_worker.py
@@ -23,6 +23,7 @@
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
@@ -413,7 +414,7 @@ def init_worker_distributed_environment(
rank,
distributed_init_method,
local_rank,
backend='hccl')
backend=current_platform.dist_backend)

ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
2 changes: 1 addition & 1 deletion vllm/worker/neuron_worker.py
@@ -156,7 +156,7 @@ def init_distributed_environment(self):
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
backend=current_platform.dist_backend,
)

ensure_model_parallel_initialized(
3 changes: 2 additions & 1 deletion vllm/worker/tpu_worker.py
@@ -15,6 +15,7 @@
init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
@@ -73,7 +74,7 @@ def init_device(self) -> None:
rank=self.rank,
local_rank=self.local_rank,
distributed_init_method=self.distributed_init_method,
backend="gloo",
backend=current_platform.dist_backend,
)
ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size,
3 changes: 2 additions & 1 deletion vllm/worker/worker.py
@@ -530,7 +530,8 @@ def init_worker_distributed_environment(
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)

init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank)
distributed_init_method, local_rank,
current_platform.dist_backend)
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)

4 changes: 1 addition & 3 deletions vllm/worker/xpu_worker.py
@@ -5,8 +5,6 @@
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch # noqa: F401
import oneccl_bindings_for_pytorch # noqa: F401
import torch
import torch.distributed

@@ -172,7 +170,7 @@ def init_worker_distributed_environment(self) -> None:
rank=rank,
distributed_init_method=distributed_init_method,
local_rank=self.local_rank,
backend="ccl")
backend=current_platform.dist_backend)

ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
Expand Down