[Model] enable data parallel for Llama4 vision encoder #18368
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
This looks good to me overall. Could you also add MM eval results with DP and TP?
cc @houseroad
Sorry for the delayed review, and thank you for the contribution! Overall this looks good, and I left some comments!
```python
num_chunks = flat_data.shape[0]
mp_world_size = get_tensor_model_parallel_world_size()
# Round up so that every rank processes the same number of chunks.
chunk_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
# Pad only the leading (chunk) dimension up to a multiple of world size.
pad = (0, 0, 0, 0, 0, 0, 0,
       chunk_per_rank * mp_world_size - num_chunks)
flat_data_padded = torch.nn.functional.pad(flat_data, pad)
rank = get_tensor_model_parallel_rank()
# Each rank takes its contiguous slice of chunks.
data_per_rank = flat_data_padded[rank * chunk_per_rank:(rank + 1) *
                                 chunk_per_rank, ...].clone()
vision_embeddings_flat = self.vision_model(data_per_rank)
# Reassemble the full set of embeddings on every rank, then drop padding.
vision_embeddings_flat = tensor_model_parallel_all_gather(
    vision_embeddings_flat, dim=0)
vision_embeddings_flat = vision_embeddings_flat[:num_chunks, ...]
```
This is great! As a follow-up, I think it makes sense to extract this into a separate function shared by other models, since this logic is not specific to the mllama4 vision encoder in particular!
maybe create an issue to follow up?
Looks pretty reasonable. Could you rebase and address @ywang96 's feedback?
Force-pushed from 70aacd3 to ac307ed
```python
# Pad only the first (chunk) dimension; F.pad takes (low, high) pairs
# starting from the last dimension.
pad = (0, ) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
image_input_padded = torch.nn.functional.pad(image_input, pad)
rank = get_tensor_model_parallel_rank()
image_input_per_rank = image_input_padded[rank *
```
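As a minimal, self-contained illustration (editor's example, not from the PR) of that pad-tuple construction: `F.pad` consumes (low, high) pairs starting from the last dimension, so the leading zeros leave every trailing dimension untouched and only the final pair grows dim 0:

```python
import torch
import torch.nn.functional as F

x = torch.randn(5, 3, 336, 336)   # e.g. 5 image chunks
num_padded_chunks = 3             # pad 5 -> 8 so the batch splits evenly over 8 ranks
pad = (0,) * (2 * (x.dim() - 1)) + (0, num_padded_chunks)
print(F.pad(x, pad).shape)        # torch.Size([8, 3, 336, 336])
```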
wondering if we need clone here?
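For context, an editor's minimal illustration (not from the PR) of what `.clone()` changes here: the plain slice is a view that shares, and keeps alive, the padded tensor's storage, while a clone owns an independent copy:

```python
import torch

t = torch.arange(8)
view = t[0:4]           # a view: shares storage with t
copy = t[0:4].clone()   # independent storage
t[0] = 100
print(view[0].item())   # 100 -- the view sees the in-place write
print(copy[0].item())   # 0   -- the clone does not
```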
Looks good to me.
```diff
@@ -390,3 +393,36 @@ def modality_group_func(mm_input: MultiModalKwargs) -> Union[str, int]:
     return [
         list(group) for _, group in groupby(mm_inputs, key=modality_group_func)
     ]
+
+
+def run_dp_sharded_vision_model(image_input: torch.Tensor,
```
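The diff excerpt cuts off at the new function's signature. For orientation, here is a sketch of what `run_dp_sharded_vision_model` plausibly looks like, reconstructed from the inlined logic quoted earlier in this thread; the second parameter name and any details beyond that excerpt are assumptions:

```python
import torch
import torch.nn.functional as F

from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_gather)


def run_dp_sharded_vision_model(image_input: torch.Tensor,
                                vision_model: torch.nn.Module) -> torch.Tensor:
    """Shard a batch of image chunks across TP ranks, run the full
    (unsharded) vision model on each shard, and all-gather the results."""
    num_chunks = image_input.shape[0]
    mp_world_size = get_tensor_model_parallel_world_size()
    # Round up so every rank receives the same number of chunks.
    num_chunks_per_rank = (num_chunks + mp_world_size - 1) // mp_world_size
    num_padded_chunks = num_chunks_per_rank * mp_world_size - num_chunks
    # Pad only the leading (chunk) dimension.
    pad = (0,) * (2 * (image_input.dim() - 1)) + (0, num_padded_chunks)
    image_input_padded = F.pad(image_input, pad)
    rank = get_tensor_model_parallel_rank()
    image_input_per_rank = image_input_padded[
        rank * num_chunks_per_rank:(rank + 1) * num_chunks_per_rank, ...]
    # Every rank runs the whole vision model on its slice of the batch.
    vision_embeddings = vision_model(image_input_per_rank)
    # Collect all shards on every rank, then drop the padded chunks.
    vision_embeddings = tensor_model_parallel_all_gather(vision_embeddings,
                                                         dim=0)
    return vision_embeddings[:num_chunks, ...]
```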
We can add some unit tests for this function (good for a follow-up PR).
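A rough idea of what such a test might look like: a sketch assuming the helper lands in `vllm.multimodal.utils` (the diff context above is from that module) and stubbing the distributed helpers so it runs in a single process; the follow-up PR's actual test may differ:

```python
import torch


def test_dp_sharded_vision_model_single_rank(monkeypatch):
    # With world_size=1 and rank=0, the sharded helper should reduce to
    # a plain forward pass through the vision model.
    import vllm.multimodal.utils as mm_utils  # assumed location of the helper

    monkeypatch.setattr(mm_utils,
                        "get_tensor_model_parallel_world_size", lambda: 1)
    monkeypatch.setattr(mm_utils,
                        "get_tensor_model_parallel_rank", lambda: 0)
    monkeypatch.setattr(mm_utils,
                        "tensor_model_parallel_all_gather",
                        lambda x, dim=0: x)

    model = torch.nn.Flatten()
    x = torch.randn(5, 3, 8, 8)
    out = mm_utils.run_dp_sharded_vision_model(x, model)
    torch.testing.assert_close(out, model(x))
```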
@houseroad The follow-up PR is here: #19103
```diff
@@ -818,9 +914,12 @@ def load_weights(self, weights: Iterable[tuple[str,
         assert loaded_language_model_params is not None
         updated_params.update(loaded_language_model_params)
 
+        if self.use_data_parallel:
+            other_weights = self._consolidate_qkv_weights(other_weights)
```
nit: we can add some sanity checks to ensure _consolidate_qkv_weights operates appropriately.
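`_consolidate_qkv_weights` itself is not shown in the excerpt. A hedged sketch of the kind of consolidation being discussed, including the sanity check suggested above; everything here beyond the function name is an assumption. The idea: in the DP vision tower the q/k/v projections form a single unsharded fused layer, so checkpoint weights named `q_proj`/`k_proj`/`v_proj` need to be concatenated into one `qkv_proj` tensor:

```python
from collections.abc import Iterable

import torch


def _consolidate_qkv_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
    """Sketch: fuse separate q/k/v projection weights into qkv_proj."""
    qkv_parts: dict[str, dict[str, torch.Tensor]] = {}
    for name, weight in weights:
        for part in ("q_proj", "k_proj", "v_proj"):
            if part in name:
                fused_name = name.replace(part, "qkv_proj")
                qkv_parts.setdefault(fused_name, {})[part] = weight
                break
        else:
            yield name, weight  # non-qkv weights pass through unchanged
    for fused_name, parts in qkv_parts.items():
        # Sanity check: all three pieces must be present before fusing.
        assert len(parts) == 3, f"missing q/k/v piece for {fused_name}"
        yield fused_name, torch.cat(
            [parts["q_proj"], parts["k_proj"], parts["v_proj"]], dim=0)
```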
@ywang96, could you give another pass?
Please address Lu's comment otherwise LGTM!
Force-pushed from 4a0e16e to e2d3ee5
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Summary: Add unit test for run_dp_sharded_vision_model, following up on vllm-project#18368.
pytest tests/multimodal/test_utils.py -k "test_run_dp_sharded_vision_model"
= 3 passed, 44 deselected, 5 warnings in 37.76s =
Signed-off-by: Siqi Yan <siqi@meta.com>
…18368)
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
…18368)
Signed-off-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Co-authored-by: yZhen <yZhen@fb.com>
Co-authored-by: yzhen <yzhen@devgpu093.cco2.facebook.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
Summary:
The Llama4 vision encoder in DP8 is ~3x as fast as in TP8, especially when handling a large number of input images (e.g., 9 images per request).
Add an enable_vision_encoder_data_parallel option to allow using different parallelism for the vision model and the language model.
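A rough usage sketch: the option name comes from this PR's summary, but how it surfaces in the engine/config arguments (and the exact model id) are assumptions here, so treat this as illustrative only:

```python
from vllm import LLM

# Hypothetical wiring: assumes the option is exposed as an engine keyword.
llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # assumed model id
    tensor_parallel_size=8,
    enable_vision_encoder_data_parallel=True,
)
```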
Unit test
pytest tests/models/multimodal/generation/test_common.py -k "llama4"
MM eval: baseline TP8 vs. DP8 [result tables omitted].
Perf result: DP8 vs. baseline TP8 [result tables omitted].