
[Model] enable data parallel for Llama4 vision encoder #18368

Merged
merged 5 commits into vllm-project:main on Jun 2, 2025

Conversation

jennyyyyzhen
Contributor

@jennyyyyzhen jennyyyyzhen commented May 19, 2025

Summary:
The Llama4 vision encoder in DP8 is ~3x as fast as in TP8, especially when handling a large number of input images (e.g., 9 images per request).
Add an enable_multimodal_encoder_data_parallel option to allow using different parallelism schemes for the vision model and the language model.
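
For illustration, a minimal offline-inference sketch of enabling this (a sketch only: it assumes the option is exposed as an engine argument, as in the lm_eval model_args below, and the image URL and prompt are placeholders):

from vllm import LLM, SamplingParams

# Sketch: keep the language model tensor-parallel across 8 GPUs while the
# vision encoder runs data-parallel on the same GPUs.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    tensor_parallel_size=8,
    max_model_len=32768,
    enable_multimodal_encoder_data_parallel=True,
)

messages = [{
    "role": "user",
    "content": [
        {"type": "image_url", "image_url": {"url": "https://example.com/chart.png"}},
        {"type": "text", "text": "What does this chart show?"},
    ],
}]

outputs = llm.chat(messages, SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)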

Unit test
pytest tests/models/multimodal/generation/test_common.py -k "llama4"

MM Eval
Baseline TP8

lm_eval --model vllm-vlm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=32768,gpu_memory_utilization=0.9 --tasks chartqa --batch_size auto --apply_chat_template
Tasks    Version  Filter  n-shot  Metric             Value   Stderr
chartqa  0        none    0       anywhere_accuracy  0.8872  ± 0.0063
                  none    0       exact_match        0.6548  ± 0.0095
                  none    0       relaxed_accuracy   0.8848  ± 0.0064

DP8

lm_eval --model vllm-vlm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,max_model_len=32768,gpu_memory_utilization=0.9,enable_multimodal_encoder_data_parallel=true --tasks chartqa --batch_size auto --apply_chat_template
Tasks    Version  Filter  n-shot  Metric             Value   Stderr
chartqa  0        none    0       anywhere_accuracy  0.8852  ± 0.0064
                  none    0       exact_match        0.6492  ± 0.0095
                  none    0       relaxed_accuracy   0.8820  ± 0.0065

Perf results

HF_CHECKPOINT=meta-llama/Llama-4-Scout-17B-16E-Instruct
 python benchmarks/benchmark_serving.py --backend openai-chat --model $HF_CHECKPOINT --dataset-name hf --dataset-path lmarena-ai/VisionArena-Chat --hf-split train --num-prompts 1000 --endpoint /v1/chat/completions --max-concurrency 32 --ignore-eos --seed 0

DP8

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  115.99    
Total input tokens:                      87321     
Total generated tokens:                  128000    
Request throughput (req/s):              8.62      
Output token throughput (tok/s):         1103.55   
Total Token throughput (tok/s):          1856.38   
---------------Time to First Token----------------
Mean TTFT (ms):                          641.44    
Median TTFT (ms):                        571.65    
P99 TTFT (ms):                           2443.98   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          23.86     
Median TPOT (ms):                        23.52     
P99 TPOT (ms):                           31.16     
---------------Inter-token Latency----------------
Mean ITL (ms):                           23.69     
Median ITL (ms):                         17.91     
P99 ITL (ms):                            222.95    
==================================================

Baseline TP8

============ Serving Benchmark Result ============
Successful requests:                     1000      
Benchmark duration (s):                  124.06    
Total input tokens:                      87321     
Total generated tokens:                  128000    
Request throughput (req/s):              8.06      
Output token throughput (tok/s):         1031.74   
Total Token throughput (tok/s):          1735.58   
---------------Time to First Token----------------
Mean TTFT (ms):                          758.51    
Median TTFT (ms):                        734.38    
P99 TTFT (ms):                           1805.42   
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          24.95     
Median TPOT (ms):                        24.71     
P99 TPOT (ms):                           30.54     
---------------Inter-token Latency----------------
Mean ITL (ms):                           24.76     
Median ITL (ms):                         17.87     
P99 ITL (ms):                            283.55    
==================================================


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Collaborator

@sarckk sarckk left a comment

This looks good to me overall; could you also add MM eval results with DP and TP?

@DarkLight1337
Member

cc @houseroad

Member

@ywang96 ywang96 left a comment

Sorry for the delayed review, and thank you for the contribution! Overall this looks good, and I left some comments!

Comment on lines 781 to 793
num_chunks = flat_data.shape[0]
mp_world_size = get_tensor_model_parallel_world_size()
chunk_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
pad = (0, 0, 0, 0, 0, 0, 0,
       chunk_per_rank * mp_world_size - num_chunks)
flat_data_padded = torch.nn.functional.pad(flat_data, pad)
rank = get_tensor_model_parallel_rank()
data_per_rank = flat_data_padded[rank * chunk_per_rank:(rank + 1) *
                                 chunk_per_rank, ...].clone()
vision_embeddings_flat = self.vision_model(data_per_rank)
vision_embeddings_flat = tensor_model_parallel_all_gather(
    vision_embeddings_flat, dim=0)
vision_embeddings_flat = vision_embeddings_flat[:num_chunks, ...]
Member

This is great! As a follow-up I think it makes sense to move this into a separate function that other models can share, since the logic is not specific to the mllama4 vision encoder.
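
A rough sketch of what such a shared helper could look like, generalized from the snippet above (an illustration of the idea under assumed names and imports, not the exact code in this PR):

import torch

from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)


def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Shard image chunks across TP ranks, run the vision encoder on each
    shard, then all-gather the embeddings and strip the padding."""
    num_chunks = image_input.shape[0]
    world_size = get_tensor_model_parallel_world_size()
    chunks_per_rank = (num_chunks + world_size - 1) // world_size

    # Pad along dim 0 so every rank receives the same number of chunks.
    num_padded_chunks = chunks_per_rank * world_size - num_chunks
    pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
    image_input_padded = torch.nn.functional.pad(image_input, pad)

    rank = get_tensor_model_parallel_rank()
    shard = image_input_padded[rank * chunks_per_rank:(rank + 1) *
                               chunks_per_rank, ...]

    embeddings = vision_model(shard)
    # Collect every rank's embeddings and drop the padded entries.
    embeddings = tensor_model_parallel_all_gather(embeddings, dim=0)
    return embeddings[:num_chunks, ...]

With a helper like this, any vision tower could opt into DP sharding without re-implementing the padding and all-gather bookkeeping.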

Collaborator

maybe create an issue to follow up?

Collaborator

@houseroad houseroad left a comment

Looks pretty reasonable. Could you rebase and address @ywang96's feedback?

@mergify mergify bot added the multi-modality Related to multi-modality (#4194) label May 30, 2025
@jennyyyyzhen jennyyyyzhen force-pushed the main branch 2 times, most recently from 70aacd3 to ac307ed Compare May 30, 2025 20:58
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
image_input_padded = torch.nn.functional.pad(image_input, pad)
rank = get_tensor_model_parallel_rank()
image_input_per_rank = image_input_padded[rank *
Collaborator

wondering if we need clone here?
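
For reference, a tiny standalone illustration of the trade-off (not from this PR): slicing along dim 0 returns a view that keeps the whole padded buffer alive, while .clone() materializes an independent, contiguous copy.

import torch

padded = torch.randn(8, 3, 336, 336)
shard = padded[2:4]          # view: shares storage with the padded buffer
owned = padded[2:4].clone()  # independent, contiguous copy

print(shard.untyped_storage().data_ptr() == padded.untyped_storage().data_ptr())  # True
print(owned.untyped_storage().data_ptr() == padded.untyped_storage().data_ptr())  # False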

Collaborator

@houseroad houseroad left a comment

Looks good to me.

@@ -390,3 +393,36 @@ def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]:
    return [
        list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
    ]


def run_dp_sharded_vision_model(image_input: torch.Tensor,
Collaborator

We can add some unit tests for this function (good for a follow-up PR).

Contributor Author

@houseroad the follow-up PR is here: #19103

@@ -818,9 +914,12 @@ def load_weights(self, weights: Iterable[tuple[str,
        assert loaded_language_model_params is not None
        updated_params.update(loaded_language_model_params)

        if self.use_data_parallel:
            other_weights = self._consolidate_qkv_weights(other_weights)
Collaborator

Nit: we can add some sanity checks to ensure _consolidate_qkv_weights operates appropriately.
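
One possible shape for such a check, as a hypothetical sketch (the actual fused layout produced by _consolidate_qkv_weights may differ): verify that the fused tensor is exactly the concatenation of the individual q/k/v weights.

import torch


def check_qkv_consolidation(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                            qkv: torch.Tensor) -> None:
    # Assumes q/k/v are fused by concatenation along the output dimension.
    expected_rows = q.shape[0] + k.shape[0] + v.shape[0]
    assert qkv.shape[0] == expected_rows, "fused qkv has unexpected output dim"
    assert torch.equal(qkv, torch.cat([q, k, v], dim=0)), \
        "fused qkv does not match the concatenated q/k/v weights"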

@houseroad houseroad added the ready ONLY add when PR is ready to merge/full CI is needed label May 31, 2025
@houseroad
Collaborator

@ywang96, could you give another pass?

Member

@ywang96 ywang96 left a comment

Please address Lu's comment; otherwise LGTM!

@jennyyyyzhen jennyyyyzhen force-pushed the main branch 3 times, most recently from 4a0e16e to e2d3ee5 Compare June 2, 2025 05:56
yZhen and others added 5 commits June 1, 2025 22:57
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
@houseroad houseroad merged commit ebb1ec9 into vllm-project:main Jun 2, 2025
67 checks passed
cryptopic added a commit to cryptopic/vllm that referenced this pull request Jun 4, 2025
Summary:
Add unit test for run_dp_sharded_vision_model, following up on vllm-project#18368

  pytest tests/multimodal/test_utils.py -k "test_run_dp_sharded_vision_model"

= 3 passed, 44 deselected, 5 warnings in 37.76s =

Signed-off-by: Siqi Yan <siqi@meta.com>
mmontuori pushed a commit to mmontuori/vllm that referenced this pull request Jun 5, 2025
[Model] enable data parallel for Llama4 vision encoder (vllm-project#18368)

Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
[Model] enable data parallel for Llama4 vision encoder (vllm-project#18368)

Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Labels
multi-modality Related to multi-modality (#4194) ready ONLY add when PR is ready to merge/full CI is needed