Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[V1][VLM] Proper memory profiling for image language models #11210

Merged
merged 9 commits into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
iterate
Signed-off-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
ywang96 committed Dec 15, 2024
commit b40fcffc4cf978e84bdd838c8d8a6e4ff69af50c
17 changes: 17 additions & 0 deletions vllm/multimodal/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,23 @@ def register_max_image_tokens(
"""
return self.register_max_multimodal_tokens("image", max_mm_tokens)

def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality
for profiling the memory usage of a model.

Note:
This is currently only used in V1.
"""

return {
key: plugin.get_max_multimodal_tokens(model_config)
for key, plugin in self._plugins.items()
}

ywang96 marked this conversation as resolved.
Show resolved Hide resolved
def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/engine/mm_input_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def cache_hit_ratio(self, steps):
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
self.mm_cache_hits / self.mm_cache_total)

# TODO: Support modalities beyond image.
def process_inputs(
self,
mm_data: MultiModalDataDict,
Expand Down
25 changes: 17 additions & 8 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,23 +610,32 @@ def profile_run(self) -> None:

# Profile with multimodal encoder & encoder cache.
# TODO (ywang96): generalize this beyond image modality.
# since mm_input_mapper only supports image inputs.
if self.is_multimodal_model:

# Create dummy patch of multimodal inputs.
# Create dummy batch of multimodal inputs.
dummy_mm_data = self.input_registry \
.dummy_data_for_profiling(self.model_config,
self.max_num_tokens,
self.mm_registry).multi_modal_data
dummy_mm_kwargs, _ = self.mm_input_mapper.process_inputs(
dummy_mm_data, None, None, None)
max_token_per_image = self.mm_registry._plugins[
'image'].get_max_multimodal_tokens(self.model_config)
max_num_images = min(
mm_data=dummy_mm_data,
mm_hashes=None,
mm_processor_kwargs=None,
precomputed_mm_inputs=None)

# NOTE: Currently model is profiled with a single non-text
# modality even when it supports multiple.
max_tokens_per_mm_item = max(
self.mm_registry.get_max_tokens_per_item_by_modality(
self.model_config).values())

max_num_mm_items = min(
self.max_num_encoder_input_tokens,
self.encoder_cache_size) // max_token_per_image
self.encoder_cache_size) // max_tokens_per_mm_item

batched_dummy_mm_inputs = MultiModalKwargs.batch(
[dummy_mm_kwargs[0] for _ in range(max_num_images)])
[dummy_mm_kwargs[0] for _ in range(max_num_mm_items)])
ywang96 marked this conversation as resolved.
Show resolved Hide resolved
batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
batched_dummy_mm_inputs, device=self.device)

Expand All @@ -635,7 +644,7 @@ def profile_run(self) -> None:
**batched_dummy_mm_inputs)

# Cache the dummy encoder outputs.
dummy_req_input_ids = [("0", i) for i in range(max_num_images)]
dummy_req_input_ids = [("0", i) for i in range(max_num_mm_items)]
self.encoder_cache["0"] = {}
for (req_id, input_id), output in zip(dummy_req_input_ids,
dummy_encoder_outputs):
Expand Down
Loading