[Feature][not ready]: Batch-Level Multimodal Embedding Mask Optimization #24456

@Ruihan11

🚀 The feature, motivation and pitch

1. Motivation

Currently, merge_multimodal_embeddings scans input_ids at runtime to find the placeholder tokens of each request. This work is redundant: the scheduler already holds mm_positions for all requests, so the placeholder locations are known before the forward pass. We should pre-compute a batch-level mask (similar to grammar_bitmask) and pass it in, instead of scanning at runtime.

The Problem

  1. torch.isin(input_ids, placeholder_token_id) - Scans the entire input_ids tensor when there are multiple placeholder token IDs
  2. (input_ids == placeholder_token_id) - Scans the entire input_ids tensor when there is a single placeholder token ID
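
For reference, the two scanning patterns above can be reproduced in isolation. This is only a minimal sketch: the token IDs (32000, 32001) and the tensor contents are made up for illustration and are not taken from any model. Both calls touch every element of input_ids, regardless of how few placeholders are present.

```python
import torch

# Hypothetical flattened batch of token IDs; 32000 and 32001 stand in
# for placeholder token IDs (illustrative values only).
input_ids = torch.tensor([1, 32000, 32000, 5, 32001, 7])

# Single placeholder ID: elementwise comparison over the whole tensor.
single_mask = input_ids == 32000

# Multiple placeholder IDs: membership test over the whole tensor.
placeholder_ids = torch.tensor([32000, 32001])
multi_mask = torch.isin(input_ids, placeholder_ids)

print(single_mask.tolist())
print(multi_mask.tolist())
```

Either result is just a boolean mask over the batch, which is the same information a mask pre-computed from mm_positions would carry.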

2. Proposed Changes

Phase 1: Core Function + Test

  • Add a merge_multimodal_embeddings_with_mask() function to utils (/vllm/model_executor/models/utils.py)
  • Add a unit test for the new function
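
A rough sketch of what Phase 1 could look like is given here. Only the function name comes from this issue; the signature, the argument names (inputs_embeds, mm_embeds, is_mm_mask), and the validation step are assumptions for discussion, not the existing vLLM API.

```python
import torch


def merge_multimodal_embeddings_with_mask(
    inputs_embeds: torch.Tensor,  # [num_tokens, hidden_size]
    mm_embeds: torch.Tensor,      # [num_mm_tokens, hidden_size]
    is_mm_mask: torch.Tensor,     # [num_tokens], dtype=torch.bool
) -> torch.Tensor:
    """Overwrite masked rows of inputs_embeds with mm_embeds, in order.

    Hypothetical signature: the mask is supplied by the caller, so
    input_ids is never scanned here.
    """
    # Sanity check. Note that .item() forces a host sync on GPU; a real
    # implementation might skip or defer this check.
    num_expected = int(is_mm_mask.sum().item())
    if mm_embeds.shape[0] != num_expected:
        raise ValueError(
            f"mask selects {num_expected} positions but got "
            f"{mm_embeds.shape[0]} multimodal embeddings"
        )
    # In-place boolean-index assignment; rows are filled in mask order.
    inputs_embeds[is_mm_mask] = mm_embeds.to(inputs_embeds.dtype)
    return inputs_embeds


# Tiny example that could double as the starting point for a unit test.
text_embeds = torch.zeros(5, 2)
mm = torch.tensor([[1.0, 1.0], [2.0, 2.0]])
mask = torch.tensor([False, True, False, True, False])
merged = merge_multimodal_embeddings_with_mask(text_embeds, mm, mask)
print(merged.tolist())
```

A unit test could additionally compare this output against the existing scanning path on the same inputs to confirm both produce identical embeddings.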

Phase 2: Integration

  • Add mask generation from mm_positions to the scheduler
  • Replace the scanning calls with the mask-based version
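
One possible shape for the mask generation is sketched below. It assumes each entry in mm_positions can be reduced to an (offset, length) pair relative to the start of its request, and that requests are laid out back to back in the batch. PlaceholderSpan and build_batch_mm_mask are hypothetical names used only for this sketch, and details such as chunked prefill (where only part of a request is scheduled) are not handled.

```python
from dataclasses import dataclass

import torch


@dataclass
class PlaceholderSpan:
    """Hypothetical stand-in for one mm_positions entry."""
    offset: int  # start index within the request's tokens
    length: int  # number of placeholder tokens


def build_batch_mm_mask(
    requests: list[tuple[int, list[PlaceholderSpan]]],
) -> torch.Tensor:
    """Build one bool mask for the batch.

    requests holds (num_tokens, spans) per request, in batch order.
    """
    total_tokens = sum(num_tokens for num_tokens, _ in requests)
    mask = torch.zeros(total_tokens, dtype=torch.bool)
    base = 0  # start of the current request within the batch
    for num_tokens, spans in requests:
        for span in spans:
            start = base + span.offset
            # Clip so a span never spills into the next request.
            end = base + min(span.offset + span.length, num_tokens)
            mask[start:end] = True
        base += num_tokens
    return mask


batch_mask = build_batch_mm_mask([
    (4, [PlaceholderSpan(offset=1, length=2)]),
    (3, []),
    (5, [PlaceholderSpan(offset=0, length=1), PlaceholderSpan(offset=3, length=2)]),
])
print(batch_mask.tolist())
```

The cost here scales with the number of placeholder spans rather than with the number of tokens scanned, and the mask can be built on the CPU side before the forward pass.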

#23891
#16229

Alternatives

No response

Additional context

No response
