[VLM] Support caching in merged multi-modal processor #11396

Merged (82 commits) on Dec 27, 2024

Changes from 1 commit

Commits (82)
faa9b84
Refactor multi-modal processor to support caching
DarkLight1337 Dec 19, 2024
9711a15
Clean up
DarkLight1337 Dec 19, 2024
29e3fcd
Fix cached result being mutated
DarkLight1337 Dec 19, 2024
ab64e85
Rename
DarkLight1337 Dec 19, 2024
81215a2
Fix docs
DarkLight1337 Dec 19, 2024
cf52b3b
Fix a typo
DarkLight1337 Dec 19, 2024
a4a8eb9
Fix unhandled sampling rate in initialization
DarkLight1337 Dec 19, 2024
c48f7c5
format
DarkLight1337 Dec 19, 2024
b84ff42
Change the delimiter
DarkLight1337 Dec 19, 2024
c3f1bde
Fix extra dimension
DarkLight1337 Dec 19, 2024
32e5197
Update
DarkLight1337 Dec 19, 2024
7264d4e
Use the inner processor to enable fine-grained caching
DarkLight1337 Dec 20, 2024
02ea829
Make the cache optional
DarkLight1337 Dec 20, 2024
b981a9d
Fix invalid kwargs being passed to tokenizer
DarkLight1337 Dec 20, 2024
5dde7d0
Fix Phi3V prompt replacement
DarkLight1337 Dec 20, 2024
7339ab8
Refine
DarkLight1337 Dec 20, 2024
509411d
Enable fine-grained caching for audio models
DarkLight1337 Dec 20, 2024
c0454f5
Add fallback
DarkLight1337 Dec 20, 2024
d50ef03
Fix typo
DarkLight1337 Dec 20, 2024
81f7d61
Fix video processor for Qwen2-VL
DarkLight1337 Dec 20, 2024
13eede3
Merge branch 'main' into mm-processor-cache
DarkLight1337 Dec 20, 2024
affbc5c
Fix a bunch of type errors
DarkLight1337 Dec 20, 2024
b4ddfb1
Fix qwen2-vl
DarkLight1337 Dec 20, 2024
4b3db32
Fix
DarkLight1337 Dec 20, 2024
dafbc7f
Simplify Pixtral-HF
DarkLight1337 Dec 21, 2024
38aaff8
Cleanup
DarkLight1337 Dec 21, 2024
5fcb5d6
Fix Pixtral-HF
DarkLight1337 Dec 21, 2024
f86e148
Enable caching outside the processing loop
DarkLight1337 Dec 21, 2024
337f0d2
Make debugging easier
DarkLight1337 Dec 21, 2024
c01d38a
Update
DarkLight1337 Dec 21, 2024
84f02fb
Fix ultravox
DarkLight1337 Dec 21, 2024
9f417c2
Revert some unnecessary changes
DarkLight1337 Dec 21, 2024
00b765b
Merge branch 'main' into mm-fields
DarkLight1337 Dec 22, 2024
2ed431e
Add test and fix some issues
DarkLight1337 Dec 23, 2024
baaf551
Update
DarkLight1337 Dec 23, 2024
f5dbcb8
Fix
DarkLight1337 Dec 23, 2024
afd3f4f
Rework
DarkLight1337 Dec 23, 2024
6172450
Rename the test
DarkLight1337 Dec 23, 2024
416943d
Update count
DarkLight1337 Dec 23, 2024
86f2786
Rename
DarkLight1337 Dec 23, 2024
f5b6214
Some fixes
DarkLight1337 Dec 23, 2024
8a68e87
Cleanup
DarkLight1337 Dec 23, 2024
ab7e84b
Skip unspecified fields
DarkLight1337 Dec 23, 2024
9f2cdaa
Fix equality checking
DarkLight1337 Dec 23, 2024
d11e833
Consolidate common code
DarkLight1337 Dec 23, 2024
5fee280
Improve error message
DarkLight1337 Dec 23, 2024
6182fd6
Cleanup
DarkLight1337 Dec 23, 2024
e1214cf
Fix Pixtral-HF
DarkLight1337 Dec 23, 2024
c717bce
Fix missing mm_count key
DarkLight1337 Dec 23, 2024
023890e
Fix qwen2-vl
DarkLight1337 Dec 23, 2024
b5e5b8a
Fix Qwen2-VL
DarkLight1337 Dec 23, 2024
cf24a1b
Fix Qwen2-VL and Qwen2-Audio
DarkLight1337 Dec 23, 2024
73271e9
Debug Phi3V
DarkLight1337 Dec 23, 2024
e30deec
Consolidate common code
DarkLight1337 Dec 23, 2024
ea6f8b5
Try to fix Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
10ae755
Remove benchmark
DarkLight1337 Dec 23, 2024
85c5e2c
Fix token mismatch in Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
4873ff8
Update max image tokens
DarkLight1337 Dec 23, 2024
4dbb5a3
Strictly check the number of placeholder tokens
DarkLight1337 Dec 23, 2024
6dbae81
Fix doc failure
DarkLight1337 Dec 23, 2024
fb51c9b
Test and fix Mantis processor
DarkLight1337 Dec 24, 2024
91cbd63
Fix embedding inputs
DarkLight1337 Dec 24, 2024
6bee6ba
Update entrypoints tests
DarkLight1337 Dec 24, 2024
cfa2ce8
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
fa54292
Clean up
DarkLight1337 Dec 24, 2024
cbf79be
Avoid extra placeholder in phi3v
DarkLight1337 Dec 24, 2024
9cd38b1
Fix OOM
DarkLight1337 Dec 24, 2024
14dcdd5
Fix mantis processor
DarkLight1337 Dec 24, 2024
b8bd2d4
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
5045d93
Remove redundant code
DarkLight1337 Dec 24, 2024
4cac998
Still need Mantis repo for testing
DarkLight1337 Dec 24, 2024
e8afd10
Merge branch 'main' into mm-fields
DarkLight1337 Dec 25, 2024
93bba0a
Fix incorrect max image tokens (Updated in #11258)
DarkLight1337 Dec 25, 2024
ea9f888
Also cache by model ID
DarkLight1337 Dec 25, 2024
58747f6
Format
DarkLight1337 Dec 25, 2024
323657a
Update link
DarkLight1337 Dec 25, 2024
695c79e
Merge branch 'main' into mm-fields
DarkLight1337 Dec 26, 2024
c67efda
Address some comments
DarkLight1337 Dec 26, 2024
d4abec7
Move `MultiModalDataItems` to `inputs` module to address more comments
DarkLight1337 Dec 26, 2024
9f4a8be
Add documentation
DarkLight1337 Dec 26, 2024
1d5b56d
Fix circular import
DarkLight1337 Dec 26, 2024
e4c7a14
Update docs
DarkLight1337 Dec 26, 2024
Enable caching outside the processing loop
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 committed Dec 21, 2024
commit f86e148a5bf3a04c403e14ffc71b2828d42980e4
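The commit message is terse, so a rough sketch of the idea may help: once each model declares how its HF processor outputs map to individual multi-modal items, the per-item results can be cached outside the processing loop, keyed by model ID (see commit ea9f888, "Also cache by model ID") plus a hash of the raw item. The class below only illustrates that shape; its name, key layout, and hashing are assumptions for illustration, not code from this PR.

# Hypothetical sketch of an item-level processor cache; only the idea of
# caching per item and keying by model ID comes from this PR.
import hashlib
import pickle
from collections import OrderedDict
from typing import Any, Callable


class ProcessorOutputCache:
    """Tiny LRU keyed by (model_id, modality, item_hash)."""

    def __init__(self, capacity: int = 256) -> None:
        self._capacity = capacity
        self._entries = OrderedDict()  # (model_id, modality, item_hash) -> output

    @staticmethod
    def _hash_item(item: Any) -> str:
        # Illustrative hashing; real inputs (PIL images, arrays) would need a
        # stable serialization.
        return hashlib.sha256(pickle.dumps(item)).hexdigest()

    def get_or_process(self, model_id: str, modality: str, item: Any,
                       process_fn: Callable[[Any], Any]) -> Any:
        key = (model_id, modality, self._hash_item(item))
        if key in self._entries:
            self._entries.move_to_end(key)  # mark as recently used
            return self._entries[key]
        out = process_fn(item)
        self._entries[key] = out
        if len(self._entries) > self._capacity:
            self._entries.popitem(last=False)  # evict least recently used
        return out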
19 changes: 15 additions & 4 deletions vllm/model_executor/models/llava.py
@@ -20,7 +20,8 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalField, MultiModalFields,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -139,11 +140,22 @@ def _call_hf_processor(

        return processed_outputs

    def _get_mm_fields(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalField]:
        return dict(
            pixel_values=MultiModalFields.index("image"),
            image_embeds=MultiModalFields.index("image"),
            is_pixtral=MultiModalFields.index("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        hf_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.ctx.get_hf_config(LlavaConfig)
        image_token_id = hf_config.image_token_index
@@ -216,7 +228,6 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
hf_mm_kwargs={},
)
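As a reading aid for the _get_mm_fields hook added above: each MultiModalFields.index("image") entry appears to declare that the leading dimension of that output enumerates the image items, which is what lets a batched HF output be split into per-item kwargs for caching. The helper below is a minimal sketch of that splitting under this assumption; it is not the PR's actual implementation.

# Illustrative only: assumes an "index" field means one leading entry per item.
from typing import Mapping

import torch


def split_indexed_outputs(
    hf_outputs: Mapping[str, torch.Tensor],
    indexed_keys: set,
) -> list:
    present = [k for k in indexed_keys if k in hf_outputs]
    if not present:
        return []
    num_items = len(hf_outputs[present[0]])
    return [{k: hf_outputs[k][i] for k in present} for i in range(num_items)]


# Two images processed together become two independently cacheable entries.
batched = {"pixel_values": torch.randn(2, 3, 336, 336)}
per_image = split_indexed_outputs(batched, {"pixel_values", "image_embeds"})
assert len(per_image) == 2
assert per_image[0]["pixel_values"].shape == (3, 336, 336)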


25 changes: 18 additions & 7 deletions vllm/model_executor/models/phi3v.py
@@ -32,7 +32,8 @@
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalField, MultiModalFields,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -306,11 +307,11 @@ def get_max_phi3v_image_tokens(
    *,
    num_crops: Optional[int] = None,
) -> int:
    hf_mm_kwargs = {}
    hf_processor_mm_kwargs = {}
    if num_crops:
        hf_mm_kwargs["num_crops"] = num_crops
        hf_processor_mm_kwargs["num_crops"] = num_crops

    processor = ctx.get_hf_processor(**hf_mm_kwargs)
    processor = ctx.get_hf_processor(**hf_processor_mm_kwargs)

    return processor.calc_num_image_tokens_from_image_size(
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
@@ -350,11 +351,22 @@ def _call_hf_processor(

        return processed_outputs

    def _get_mm_fields(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalField]:
        return dict(
            pixel_values=MultiModalFields.index("image"),
            image_sizes=MultiModalFields.index("image"),
            image_embeds=MultiModalFields.index("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        hf_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
@@ -399,7 +411,6 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text="".join(image_tokens[:num_images]),
mm_data=data,
hf_mm_kwargs={},
)


33 changes: 17 additions & 16 deletions vllm/model_executor/models/qwen2_audio.py
@@ -38,7 +38,8 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.inputs import (MultiModalField, MultiModalFields,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -73,7 +74,7 @@ def forward(self, audio_features):


# From Qwen2AudioEncoder._get_feat_extract_output_lengths
def _get_feat_extract_output_lengths(input_lengths: torch.LongTensor):
def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths
@@ -127,10 +128,6 @@ def _call_hf_processor(
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
# When fine-grained caching is applied,
# the individual processors are called separately.
return_attention_mask=True,
padding="max_length",
)
else:
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
@@ -142,27 +139,32 @@ def _call_hf_processor(
mm_kwargs=mm_kwargs,
)

# When fine-grained caching is applied,
# the individual processors are called separately.
if "attention_mask" in processed_outputs:
processed_outputs["feature_attention_mask"] = \
processed_outputs.pop("attention_mask")

return processed_outputs

    def _get_mm_fields(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalField]:
        return dict(
            input_features=MultiModalFields.index("audio"),
            feature_attention_mask=MultiModalFields.index("audio"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        hf_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_config = self.ctx.get_hf_config(Qwen2AudioConfig)
        placeholder = hf_config.audio_token_index

        feature_attention_mask = hf_inputs.get("feature_attention_mask")
        feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
        if feature_attention_mask is None:
            audio_output_lengths = []
        else:
            assert isinstance(feature_attention_mask, torch.Tensor)
            _, audio_output_lengths = _get_feat_extract_output_lengths(
                feature_attention_mask.sum(-1))

@@ -192,7 +194,6 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text="<|AUDIO|>" * audio_count,
mm_data=data,
hf_mm_kwargs={},
)
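A quick worked example of the _get_feat_extract_output_lengths arithmetic shown in the qwen2_audio diff above (the formula is copied from the diff; the 3000-frame input is just a convenient sample, matching Whisper's 30-second mel window):

import torch


def _get_feat_extract_output_lengths(input_lengths: torch.Tensor):
    # Two successive stride-2 reductions, as in the diff above.
    feat_lengths = (input_lengths - 1) // 2 + 1
    output_lengths = (feat_lengths - 2) // 2 + 1
    return feat_lengths, output_lengths


feat, out = _get_feat_extract_output_lengths(torch.tensor([3000]))
assert feat.item() == 1500 and out.item() == 750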


63 changes: 37 additions & 26 deletions vllm/model_executor/models/qwen2_vl.py
@@ -54,7 +54,9 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalDataDict, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalField,
                                    MultiModalFields, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -805,39 +807,19 @@ def _get_hf_mm_data(
and v[0].ndim == 2):
# Pass through embedding inputs (multi)
passthrough_data[f"{k}_embeds"] = v
else:
elif len(v) > 0:
# Map keys to plural form, e.g.: image -> images
processor_data[f"{k}s"] = v
else:
processor_data[k] = v

return processor_data, passthrough_data

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        # Remove the extra dimension
        if (not self.ctx.model_config.disable_mm_preprocessor_cache
                and "pixel_values" in processed_outputs):
            processed_outputs["pixel_values"] = \
                processed_outputs["pixel_values"].squeeze(0)

        return processed_outputs

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        hf_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        image_processor = _get_image_processor(hf_processor)
@@ -851,7 +833,9 @@ def _get_prompt_replacements(
        merge_length = image_processor.merge_size**2

        def get_replacement_qwen2vl(item_idx: int, modality: str):
            grid_thw = hf_inputs[f"{modality}_grid_thw"][item_idx]
            grid_thw = out_mm_kwargs[f"{modality}_grid_thw"][item_idx]
            assert isinstance(grid_thw, torch.Tensor)

            num_tokens = grid_thw.prod() // merge_length
            return placeholder[modality] * num_tokens

@@ -864,6 +848,34 @@ def get_replacement_qwen2vl(item_idx: int, modality: str):
) for modality in ("image", "video")
]

    def _get_mm_fields(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalField]:
        image_grid_thw = hf_inputs.get("image_grid_thw", torch.empty((0, 3)))
        image_slice_idxs = image_grid_thw.prod(-1).tolist() + [None]
        image_slices = [
            slice(image_slice_idxs[i], image_slice_idxs[i + 1])
            for i in range(len(image_grid_thw))
        ]

        video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
        video_slice_idxs = video_grid_thw.prod(-1).tolist() + [None]
        video_slices = [
            slice(video_slice_idxs[i], video_slice_idxs[i + 1])
            for i in range(len(video_grid_thw))
        ]

        return dict(
            pixel_values=MultiModalFields.flat("image", image_slices),
            image_embeds=MultiModalFields.flat("image", image_slices),
            image_grid_thw=MultiModalFields.index("image"),
            pixel_values_videos=MultiModalFields.flat("video", video_slices),
            video_embeds=MultiModalFields.flat("video", video_slices),
            video_grid_thw=MultiModalFields.index("video"),
        )

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
@@ -889,7 +901,6 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text=image_token * num_images,
mm_data=data,
hf_mm_kwargs={},
)
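The MultiModalFields.flat(...) entries above carve one concatenated tensor (for example pixel_values, which stacks the patches of every image) into per-item chunks. Below is a minimal sketch of that slicing using cumulative patch counts derived from grid_thw; it illustrates the idea only and is not the PR's MultiModalFields.flat implementation.

import torch


def per_item_slices(grid_thw: torch.Tensor) -> list:
    # grid_thw has shape (num_items, 3); each row's product is the number of
    # flattened patches contributed by that item.
    counts = grid_thw.prod(-1)
    offsets = [0] + counts.cumsum(0).tolist()
    return [slice(offsets[i], offsets[i + 1]) for i in range(len(counts))]


# Two images contributing 1*2*2 = 4 and 1*4*4 = 16 patches, concatenated into
# a single (20, hidden_dim) tensor; the slices recover each image's patches.
grid_thw = torch.tensor([[1, 2, 2], [1, 4, 4]])
pixel_values = torch.randn(20, 1176)
slices = per_item_slices(grid_thw)
assert [pixel_values[s].shape[0] for s in slices] == [4, 16]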


21 changes: 16 additions & 5 deletions vllm/model_executor/models/ultravox.py
@@ -23,7 +23,9 @@
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalField, MultiModalFields,
                                    MultiModalKwargs, NestedTensors)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataItems, ProcessorInputs,
                                        PromptReplacement)
@@ -144,17 +146,27 @@ def _call_hf_processor(
)
        return BatchFeature(combined_outputs)

    def _get_mm_fields(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalField]:
        return dict(
            audio_features=MultiModalFields.index("audio"),
            audio_embeds=MultiModalFields.index("audio"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        hf_mm_kwargs: Mapping[str, object],
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        hf_processor = self._get_hf_processor()
        placeholder = hf_processor.audio_token_replacement  # type: ignore

        def get_replacement_ultravox(item_idx: int):
            audio_token_len = hf_inputs["audio_token_len"][item_idx]
            audio_token_len = out_mm_kwargs["audio_token_len"][item_idx]
            return placeholder * audio_token_len

        return [
@@ -180,7 +192,6 @@ def _get_dummy_mm_inputs(
return ProcessorInputs(
prompt_text="<|audio|>" * audio_count,
mm_data=data,
hf_mm_kwargs={},
)
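One more small illustration of the pattern the ultravox change above establishes: prompt-replacement callbacks now read per-item values from out_mm_kwargs (the already-split, possibly cached outputs) rather than the raw hf_inputs. The snippet below mimics that callback shape with made-up lengths; it is not code from the PR.

placeholder = "<|audio|>"
# Stand-in for out_mm_kwargs["audio_token_len"], one entry per audio item.
audio_token_len = [12, 7]


def get_replacement_ultravox(item_idx: int) -> str:
    return placeholder * audio_token_len[item_idx]


assert get_replacement_ultravox(0) == "<|audio|>" * 12
assert get_replacement_ultravox(1) == "<|audio|>" * 7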

