45 changes: 45 additions & 0 deletions docs/configuration/optimization.md
@@ -129,6 +129,51 @@ Data parallelism replicates the entire model across multiple GPU sets and processes
Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`.
Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size.

### Batch-level DP for Multi-Modal Encoders

By default, TP is used to shard the weights of multi-modal encoders just like for language decoders,
in order to reduce the memory and compute load on each GPU.

However, since the size of multi-modal encoders is very small compared to language decoders,
there is relatively little gain from TP. On the other hand, TP incurs significant communication
overhead because an all-reduce is performed after every layer.

Given this, it may be advantageous to instead shard the batched input data across the TP ranks, essentially
performing batch-level DP. This has been shown to improve throughput by around 10% for
`tensor_parallel_size=8`. For vision encoders that use hardware-unoptimized Conv3D operations,
batch-level DP can provide a further 40% increase in throughput compared to regular TP.

Nevertheless, since the weights of the multi-modal encoder are replicated on each TP rank,
there will be a minor increase in memory consumption, which may cause OOM if you can barely fit the model already.

You can enable batch-level DP by setting `mm_encoder_tp_mode="data"`, for example:

```python
from vllm import LLM

llm = LLM(
model="Qwen/Qwen2.5-VL-72B-Instruct",
# Create two EngineCore instances, one per DP rank
data_parallel_size=2,
# Within each EngineCore instance:
# The vision encoder uses TP=4 (not DP=2) to shard the input data
# The language decoder uses TP=4 to shard the weights as usual
tensor_parallel_size=4,
mm_encoder_tp_mode="data",
)
```
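
For comparison, the `tensor_parallel_size=8` throughput figure quoted above corresponds to a TP-only deployment without request-level DP; a minimal sketch of that setup (the model choice here is illustrative):

```python
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-72B-Instruct",
    # Single EngineCore instance: the language decoder is sharded with TP=8
    # as usual, while each TP rank keeps a full copy of the vision encoder
    # and processes 1/8 of the batched multi-modal inputs.
    tensor_parallel_size=8,
    mm_encoder_tp_mode="data",
)
```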

!!! important
    Batch-level DP is not to be confused with API request-level DP
    (which is instead controlled by `data_parallel_size`).

The availability of batch-level DP depends on the model implementation.
Currently, the following models support `mm_encoder_tp_mode="data"`:

- Llama4 (<gh-pr:18368>)
- Qwen2.5-VL (<gh-pr:22742>)
- Step3 (<gh-pr:22697>)

## Input Processing

### Parallel Processing
34 changes: 33 additions & 1 deletion vllm/config/__init__.py
@@ -248,6 +248,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
LogprobsMode = Literal["raw_logprobs", "raw_logits", "processed_logprobs",
"processed_logits"]
MMEncoderTPMode = Literal["weights", "data"]


@config
@@ -428,6 +429,19 @@ class ModelConfig:
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.

Set to `0` to disable this cache completely (not recommended)."""
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).

- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
@@ -846,8 +860,10 @@ def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
media_io_kwargs=self.media_io_kwargs,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
interleave_mm_strings=self.interleave_mm_strings,
skip_mm_profiling=self.skip_mm_profiling)
skip_mm_profiling=self.skip_mm_profiling,
)

return None

@@ -2516,6 +2532,22 @@ class MultiModalConfig:
Set to `0` to disable this cache completely (not recommended).
"""

mm_encoder_tp_mode: MMEncoderTPMode = "weights"
"""
Indicates how to optimize multi-modal encoder inference using
tensor parallelism (TP).

- `"weights"`: Within the same vLLM engine, split the weights of
each layer across TP ranks. (default TP behavior)
- `"data"`: Within the same vLLM engine, split the batched input data
across TP ranks to process the data in parallel, while hosting
the full weights on each TP rank.
This batch-level DP is not to be confused with API request-level
DP (which is controlled by `--data-parallel-size`).
This is only supported on a per-model basis and falls back to
`"weights"` if the encoder does not support DP.
"""

interleave_mm_strings: bool = False
"""
Enable fully interleaved support for multimodal prompts.
4 changes: 0 additions & 4 deletions vllm/config/parallel.py
@@ -137,10 +137,6 @@ class is dynamically inherited by the worker class. This is used to inject
rank: int = 0
"""Global rank in distributed setup."""

enable_multimodal_encoder_data_parallel: bool = False
""" Use data parallelism instead of tensor parallelism for vision encoder.
Only support LLama4 for now"""

@property
def world_size_across_dp(self) -> int:
"""world_size_across_dp is TPxPPxDP, it is the size of the world
35 changes: 22 additions & 13 deletions vllm/engine/arg_utils.py
@@ -27,12 +27,12 @@
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
KVTransferConfig, LoadConfig, LogprobsMode,
LoRAConfig, MambaDType, ModelConfig, ModelDType,
ModelImpl, MultiModalConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
RunnerOption, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
VllmConfig, get_attr_docs, get_field)
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
SchedulerPolicy, SpeculativeConfig, TaskOption,
TokenizerMode, VllmConfig, get_attr_docs, get_field)
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.plugins import load_general_plugins
@@ -351,6 +351,7 @@ class EngineArgs:
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
enable_lora: bool = False
@@ -433,16 +434,14 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location

enable_multimodal_encoder_data_parallel: bool = \
ParallelConfig.enable_multimodal_encoder_data_parallel
# DEPRECATED
enable_multimodal_encoder_data_parallel: bool = False

logits_processors: Optional[list[Union[
str, type[LogitsProcessor]]]] = ModelConfig.logits_processors
"""Custom logitproc types"""

async_scheduling: bool = SchedulerConfig.async_scheduling
# DEPRECATED
enable_prompt_adapter: bool = False
Comment on lines -444 to -445

Member Author: Prompt adapters have been fully removed so I removed this unused code along the way.

Member Author: Oops, thanks for fixing it


kv_sharing_fast_prefill: bool = \
CacheConfig.kv_sharing_fast_prefill
@@ -677,7 +676,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
**parallel_kwargs["worker_extension_cls"])
parallel_group.add_argument(
"--enable-multimodal-encoder-data-parallel",
**parallel_kwargs["enable_multimodal_encoder_data_parallel"])
action="store_true",
deprecated=True)

# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
@@ -727,6 +727,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
multimodal_group.add_argument("--disable-mm-preprocessor-cache",
action="store_true",
deprecated=True)
multimodal_group.add_argument(
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
multimodal_group.add_argument(
"--interleave-mm-strings",
**multimodal_kwargs["interleave_mm_strings"])
@@ -901,6 +903,14 @@ def create_model_config(self) -> ModelConfig:

self.mm_processor_cache_gb = envs.VLLM_MM_INPUT_CACHE_GIB

if self.enable_multimodal_encoder_data_parallel:
logger.warning(
"--enable-multimodal-encoder-data-parallel` is deprecated "
"and will be removed in v0.13. "
"Please use `--mm-encoder-tp-mode data` instead.")

self.mm_encoder_tp_mode = "data"

return ModelConfig(
model=self.model,
hf_config_path=self.hf_config_path,
@@ -939,6 +949,7 @@ def create_model_config(self) -> ModelConfig:
config_format=self.config_format,
mm_processor_kwargs=self.mm_processor_kwargs,
mm_processor_cache_gb=self.mm_processor_cache_gb,
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
override_neuron_config=self.override_neuron_config,
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
@@ -1250,8 +1261,6 @@ def create_engine_config(
distributed_executor_backend=self.distributed_executor_backend,
worker_cls=self.worker_cls,
worker_extension_cls=self.worker_extension_cls,
enable_multimodal_encoder_data_parallel=self.
enable_multimodal_encoder_data_parallel,
)

if model_config.is_multimodal_model:
4 changes: 2 additions & 2 deletions vllm/model_executor/models/mllama4.py
@@ -728,8 +728,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
multimodal_config = vllm_config.model_config.multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

self.config = config
self.quant_config = quant_config
self.multimodal_config = multimodal_config
3 changes: 1 addition & 2 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -877,8 +877,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
multimodal_config = vllm_config.model_config.multimodal_config

self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
self.config = config
self.multimodal_config = multimodal_config

3 changes: 1 addition & 2 deletions vllm/model_executor/models/step3_vl.py
@@ -882,8 +882,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:

self.config = config
self.multimodal_config = multimodal_config
self.use_data_parallel = (vllm_config.parallel_config.
enable_multimodal_encoder_data_parallel)
self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"

if multimodal_config.get_limit_per_prompt("image"):
self.vision_model = Step3VisionTransformer(
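
As context for the three model changes above, here is a simplified, self-contained sketch (illustrative only; the class, helper logic, and tensor shapes below are not vLLM's actual implementation) of what `use_data_parallel` toggles inside a vision encoder: replicated weights with the batch split across TP ranks, versus the usual weight sharding over the full batch.

```python
import torch
import torch.nn as nn


class ToyVisionEncoder(nn.Module):
    """Illustrative only; not vLLM's actual encoder."""

    def __init__(self, hidden: int, tp_size: int, tp_rank: int,
                 use_data_parallel: bool):
        super().__init__()
        self.tp_size = tp_size
        self.tp_rank = tp_rank
        self.use_data_parallel = use_data_parallel
        # "data" mode: every rank holds the full projection weights.
        # "weights" mode: each rank holds a 1/tp_size shard of the output dim.
        out_features = hidden if use_data_parallel else hidden // tp_size
        self.proj = nn.Linear(hidden, out_features)

    def forward(self, pixel_features: torch.Tensor) -> torch.Tensor:
        if self.use_data_parallel:
            # Batch-level DP: this rank runs the full encoder on its slice of
            # the batch; the slices are gathered across ranks afterwards.
            local = pixel_features.chunk(self.tp_size, dim=0)[self.tp_rank]
            return self.proj(local)
        # Regular TP: the full batch passes through this rank's weight shard;
        # an all-gather/all-reduce across ranks follows in real code.
        return self.proj(pixel_features)


# Toy usage on a single process, pretending to be rank 0 of a TP=4 group.
x = torch.randn(8, 64)  # a "batch" of 8 image feature vectors
encoder = ToyVisionEncoder(64, tp_size=4, tp_rank=0, use_data_parallel=True)
print(encoder(x).shape)  # torch.Size([2, 64]): 1/4 of the batch, full width
```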