
[Platform] Move platform check to right place #18470

Merged (1 commit) on May 22, 2025
35 changes: 10 additions & 25 deletions vllm/config.py
@@ -42,7 +42,10 @@
     try_get_generation_config, uses_mrope)
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
-from vllm.utils import (GiB_bytes, LayerBlockType, cuda_device_count_stateless,
+from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
+                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS, GiB_bytes,
+                        LayerBlockType, cuda_device_count_stateless,
                         get_cpu_memory, get_open_port, is_torch_equal_or_newer,
                         random_uuid, resolve_obj_by_qualname)

@@ -64,12 +67,6 @@

 ConfigT = TypeVar("ConfigT", bound=ConfigType)

-# This value is chosen to have a balance between ITL and TTFT. Note it is
-# not optimized for throughput.
-_DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
-_POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
-_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
-
 TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
                      "score", "reward", "transcription"]

@@ -2074,28 +2071,28 @@ def __post_init__(self) -> None:
                     # so we don't reject sequences on account of a short
                     # max_num_batched_tokens.
                     self.max_num_batched_tokens = max(
-                        self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                        self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
                 else:
                     self.max_num_batched_tokens = (
-                        _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                        DEFAULT_MAX_NUM_BATCHED_TOKENS)
             else:
                 # If max_model_len is too short, use
-                # _DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
+                # DEFAULT_MAX_NUM_BATCHED_TOKENS as the default value
                 # for higher throughput.
                 self.max_num_batched_tokens = max(
-                    self.max_model_len, _DEFAULT_MAX_NUM_BATCHED_TOKENS)
+                    self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)

             if self.runner_type == "pooling":
                 # Choose specific value for higher throughput
                 self.max_num_batched_tokens = max(
                     self.max_num_batched_tokens,
-                    _POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
+                    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
                 )
             if self.is_multimodal_model:
                 # The value needs to be at least the number of multimodal tokens
                 self.max_num_batched_tokens = max(
                     self.max_num_batched_tokens,
-                    _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
+                    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                 )

             # When using default settings,
@@ -4316,18 +4313,6 @@ def __post_init__(self):
                     "full_cuda_graph is not supported with "
                     "cascade attention. Disabling cascade attention.")
                 self.model_config.disable_cascade_attn = True

-        if self.model_config and self.model_config.use_mla and \
-            not (current_platform.is_cuda() or current_platform.is_rocm()):
-            logger.info(
-                "MLA is enabled on a non-GPU platform; forcing chunked "
-                "prefill and prefix caching to be disabled.")
-            self.scheduler_config.enable_chunked_prefill = False
-            self.scheduler_config.chunked_prefill_enabled = False
-            self.scheduler_config.max_num_batched_tokens = max(
-                self.scheduler_config.max_model_len,
-                _DEFAULT_MAX_NUM_BATCHED_TOKENS)

-            if self.cache_config is not None:
-                self.cache_config.enable_prefix_caching = False
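Note (not part of the diff): the check removed above is not lost; it now runs through the platform hook instead. Below is a minimal sketch of the dispatch, assuming vLLM's usual pattern where config finalization hands the config to the active platform; `finalize_config` is a hypothetical wrapper for illustration only.

```python
# Minimal sketch, not part of this PR. Assumes the usual vLLM pattern in
# which the active Platform subclass is asked to validate/adjust the config.
from vllm.config import VllmConfig
from vllm.platforms import current_platform


def finalize_config(vllm_config: VllmConfig) -> None:
    # Hypothetical wrapper for illustration: the CPU/HPU/Neuron/TPU/XPU
    # classes below each override check_and_update_config, so the MLA
    # fixup now runs only on the platforms it applies to.
    current_platform.check_and_update_config(vllm_config)
```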
11 changes: 11 additions & 0 deletions vllm/platforms/cpu.py
@@ -9,6 +9,7 @@
 import torch

 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import CpuArchEnum, Platform, PlatformEnum, _Backend

@@ -177,6 +178,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 " set VLLM_WORKER_MULTIPROC_METHOD to fork explicitly.")
             os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on CPU.")
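The block added here recurs verbatim in the hpu, neuron, tpu, and xpu hooks below. As a readability aid only, here is that shared logic as a standalone sketch; `_mla_disable_chunked_prefill` is a hypothetical helper, not something the PR introduces.

```python
# Hypothetical helper, shown only to summarize the block each non-GPU
# platform hook repeats below.
from vllm.config import VllmConfig
from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS


def _mla_disable_chunked_prefill(vllm_config: VllmConfig) -> None:
    scheduler = vllm_config.scheduler_config
    scheduler.enable_chunked_prefill = False
    scheduler.chunked_prefill_enabled = False
    # With chunked prefill off, a whole prompt must fit into one batch,
    # so the budget is raised to at least the model's full context length.
    scheduler.max_num_batched_tokens = max(
        scheduler.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS)
```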
11 changes: 11 additions & 0 deletions vllm/platforms/hpu.py
@@ -7,6 +7,7 @@

 from vllm import envs
 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import Platform, PlatformEnum, _Backend

@@ -80,6 +81,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 "VLLM_WORKER_MULTIPROC_METHOD=fork explicitly.")
             os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on HPU.")
11 changes: 11 additions & 0 deletions vllm/platforms/neuron.py
@@ -6,6 +6,7 @@

 from vllm import envs
 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import Platform, PlatformEnum

@@ -56,6 +57,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
             vllm_config.cache_config.block_size = \
                 vllm_config.model_config.max_model_len  # type: ignore

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
11 changes: 11 additions & 0 deletions vllm/platforms/tpu.py
@@ -9,6 +9,7 @@
 from vllm.inputs import ProcessorInputs, PromptType
 from vllm.logger import init_logger
 from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import Platform, PlatformEnum, _Backend

@@ -161,6 +162,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                     "Forcing --disable_chunked_mm_input.")
                 scheduler_config.disable_chunked_mm_input = True

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on TPU.")
11 changes: 11 additions & 0 deletions vllm/platforms/xpu.py
@@ -5,6 +5,7 @@
 import torch

 from vllm.logger import init_logger
+from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS

 from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

@@ -113,6 +114,16 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
                 parallel_config.distributed_executor_backend)
             parallel_config.distributed_executor_backend = "ray"

+        if vllm_config.model_config and vllm_config.model_config.use_mla:
+            logger.info(
+                "MLA is enabled on a non-GPU platform; forcing chunked "
+                "prefill and prefix caching to be disabled.")
+            vllm_config.scheduler_config.enable_chunked_prefill = False
+            vllm_config.scheduler_config.chunked_prefill_enabled = False
+            vllm_config.scheduler_config.max_num_batched_tokens = max(
+                vllm_config.scheduler_config.max_model_len,
+                DEFAULT_MAX_NUM_BATCHED_TOKENS)
+
     @classmethod
     def is_pin_memory_available(cls):
         logger.warning("Pin memory is not supported on XPU.")
6 changes: 6 additions & 0 deletions vllm/utils.py
@@ -77,6 +77,12 @@

 logger = init_logger(__name__)

+# This value is chosen to have a balance between ITL and TTFT. Note it is
+# not optimized for throughput.
+DEFAULT_MAX_NUM_BATCHED_TOKENS = 2048
+POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
+MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120
+
 # Exception strings for non-implemented encoder/decoder scenarios

 # Reminder: Please update docs/source/features/compatibility_matrix.md
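Dropping the leading underscores and hosting the constants in vllm.utils lets the platform modules import them directly; importing them from vllm/config.py instead would presumably risk a circular import, since config.py already depends on the platform layer. A quick usage sketch, with the values taken from the diff above:

```python
# The constants are now public, so platform code (or anything else) can
# import them from vllm.utils without touching vllm/config.py.
from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)

assert DEFAULT_MAX_NUM_BATCHED_TOKENS == 2048            # ITL/TTFT balance
assert POOLING_MODEL_MAX_NUM_BATCHED_TOKENS == 32768     # throughput-oriented
assert MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS == 5120   # multimodal floor
```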