[VLM] Support caching in merged multi-modal processor #11396

Merged · 82 commits · Dec 27, 2024

Changes from 1 commit

Commits
faa9b84
Refactor multi-modal processor to support caching
DarkLight1337 Dec 19, 2024
9711a15
Clean up
DarkLight1337 Dec 19, 2024
29e3fcd
Fix cached result being mutated
DarkLight1337 Dec 19, 2024
ab64e85
Rename
DarkLight1337 Dec 19, 2024
81215a2
Fix docs
DarkLight1337 Dec 19, 2024
cf52b3b
Fix a typo
DarkLight1337 Dec 19, 2024
a4a8eb9
Fix unhandled sampling rate in initialization
DarkLight1337 Dec 19, 2024
c48f7c5
format
DarkLight1337 Dec 19, 2024
b84ff42
Change the delimiter
DarkLight1337 Dec 19, 2024
c3f1bde
Fix extra dimension
DarkLight1337 Dec 19, 2024
32e5197
Update
DarkLight1337 Dec 19, 2024
7264d4e
Use the inner processor to enable fine-grained caching
DarkLight1337 Dec 20, 2024
02ea829
Make the cache optional
DarkLight1337 Dec 20, 2024
b981a9d
Fix invalid kwargs being passed to tokenizer
DarkLight1337 Dec 20, 2024
5dde7d0
Fix Phi3V prompt replacement
DarkLight1337 Dec 20, 2024
7339ab8
Refine
DarkLight1337 Dec 20, 2024
509411d
Enable fine-grained caching for audio models
DarkLight1337 Dec 20, 2024
c0454f5
Add fallback
DarkLight1337 Dec 20, 2024
d50ef03
Fix typo
DarkLight1337 Dec 20, 2024
81f7d61
Fix video processor for Qwen2-VL
DarkLight1337 Dec 20, 2024
13eede3
Merge branch 'main' into mm-processor-cache
DarkLight1337 Dec 20, 2024
affbc5c
Fix a bunch of type errors
DarkLight1337 Dec 20, 2024
b4ddfb1
Fix qwen2-vl
DarkLight1337 Dec 20, 2024
4b3db32
Fix
DarkLight1337 Dec 20, 2024
dafbc7f
Simplify Pixtral-HF
DarkLight1337 Dec 21, 2024
38aaff8
Cleanup
DarkLight1337 Dec 21, 2024
5fcb5d6
Fix Pixtral-HF
DarkLight1337 Dec 21, 2024
f86e148
Enable caching outside the processing loop
DarkLight1337 Dec 21, 2024
337f0d2
Make debugging easier
DarkLight1337 Dec 21, 2024
c01d38a
Update
DarkLight1337 Dec 21, 2024
84f02fb
Fix ultravox
DarkLight1337 Dec 21, 2024
9f417c2
Revert some unnecessary changes
DarkLight1337 Dec 21, 2024
00b765b
Merge branch 'main' into mm-fields
DarkLight1337 Dec 22, 2024
2ed431e
Add test and fix some issues
DarkLight1337 Dec 23, 2024
baaf551
Update
DarkLight1337 Dec 23, 2024
f5dbcb8
Fix
DarkLight1337 Dec 23, 2024
afd3f4f
Rework
DarkLight1337 Dec 23, 2024
6172450
Rename the test
DarkLight1337 Dec 23, 2024
416943d
Update count
DarkLight1337 Dec 23, 2024
86f2786
Rename
DarkLight1337 Dec 23, 2024
f5b6214
Some fixes
DarkLight1337 Dec 23, 2024
8a68e87
Cleanup
DarkLight1337 Dec 23, 2024
ab7e84b
Skip unspecified fields
DarkLight1337 Dec 23, 2024
9f2cdaa
Fix equality checking
DarkLight1337 Dec 23, 2024
d11e833
Consolidate common code
DarkLight1337 Dec 23, 2024
5fee280
Improve error message
DarkLight1337 Dec 23, 2024
6182fd6
Cleanup
DarkLight1337 Dec 23, 2024
e1214cf
Fix Pixtral-HF
DarkLight1337 Dec 23, 2024
c717bce
Fix missing mm_count key
DarkLight1337 Dec 23, 2024
023890e
Fix qwen2-vl
DarkLight1337 Dec 23, 2024
b5e5b8a
Fix Qwen2-VL
DarkLight1337 Dec 23, 2024
cf24a1b
Fix Qwen2-VL and Qwen2-Audio
DarkLight1337 Dec 23, 2024
73271e9
Debug Phi3V
DarkLight1337 Dec 23, 2024
e30deec
Consolidate common code
DarkLight1337 Dec 23, 2024
ea6f8b5
Try to fix Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
10ae755
Remove benchmark
DarkLight1337 Dec 23, 2024
85c5e2c
Fix token mismatch in Phi3V and Ultravox
DarkLight1337 Dec 23, 2024
4873ff8
Update max image tokens
DarkLight1337 Dec 23, 2024
4dbb5a3
Strictly check the number of placeholder tokens
DarkLight1337 Dec 23, 2024
6dbae81
Fix doc failure
DarkLight1337 Dec 23, 2024
fb51c9b
Test and fix Mantis processor
DarkLight1337 Dec 24, 2024
91cbd63
Fix embedding inputs
DarkLight1337 Dec 24, 2024
6bee6ba
Update entrypoints tests
DarkLight1337 Dec 24, 2024
cfa2ce8
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
fa54292
Clean up
DarkLight1337 Dec 24, 2024
cbf79be
Avoid extra placeholder in phi3v
DarkLight1337 Dec 24, 2024
9cd38b1
Fix OOM
DarkLight1337 Dec 24, 2024
14dcdd5
Fix mantis processor
DarkLight1337 Dec 24, 2024
b8bd2d4
Merge branch 'main' into mm-fields
DarkLight1337 Dec 24, 2024
5045d93
Remove redundant code
DarkLight1337 Dec 24, 2024
4cac998
Still need Mantis repo for testing
DarkLight1337 Dec 24, 2024
e8afd10
Merge branch 'main' into mm-fields
DarkLight1337 Dec 25, 2024
93bba0a
Fix incorrect max image tokens (Updated in #11258)
DarkLight1337 Dec 25, 2024
ea9f888
Also cache by model ID
DarkLight1337 Dec 25, 2024
58747f6
Format
DarkLight1337 Dec 25, 2024
323657a
Update link
DarkLight1337 Dec 25, 2024
695c79e
Merge branch 'main' into mm-fields
DarkLight1337 Dec 26, 2024
c67efda
Address some comments
DarkLight1337 Dec 26, 2024
d4abec7
Move `MultiModalDataItems` to `inputs` module to address more comments
DarkLight1337 Dec 26, 2024
9f4a8be
Add documentation
DarkLight1337 Dec 26, 2024
1d5b56d
Fix circular import
DarkLight1337 Dec 26, 2024
e4c7a14
Update docs
DarkLight1337 Dec 26, 2024
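Taken together, the commit messages above sketch the design: the merged multi-modal processor caches its per-item outputs, keyed by both the model ID and the item's content, the cache is optional, and fine-grained per-item caching works by calling into the inner HF processor. A minimal sketch of that idea, assuming an LRU policy; the names below (`ProcessingCache`, `get_or_process`, `_hash_item`) are illustrative and not the actual vLLM API:

```python
import hashlib
import pickle
from collections import OrderedDict
from typing import Any, Callable


class ProcessingCache:
    """Illustrative LRU cache for per-item processor outputs (not vLLM's API)."""

    def __init__(self, capacity: int = 256) -> None:
        self.capacity = capacity
        self._cache: OrderedDict[str, Any] = OrderedDict()

    def _hash_item(self, model_id: str, modality: str, item: Any) -> str:
        # Key by model ID as well as item content, so two models sharing a
        # process never see each other's cached outputs.
        raw = pickle.dumps((model_id, modality, item))
        return hashlib.sha256(raw).hexdigest()

    def get_or_process(self, model_id: str, modality: str, item: Any,
                       process_fn: Callable[[Any], Any]) -> Any:
        key = self._hash_item(model_id, modality, item)
        if key in self._cache:
            self._cache.move_to_end(key)  # LRU bookkeeping on a hit
            return self._cache[key]
        out = process_fn(item)  # fall back to the (slow) HF processor
        self._cache[key] = out
        if len(self._cache) > self.capacity:
            self._cache.popitem(last=False)  # evict the least recently used entry
        return out
```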
Fix Qwen2-VL and Qwen2-Audio
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
DarkLight1337 committed Dec 23, 2024
commit cf24a1b2077d8def5ba7ca027caf8ce7006c070d
172 changes: 172 additions & 0 deletions benchmarks/mmmu_bench.py
@@ -0,0 +1,172 @@
r"""Benchmark offline inference throughput with MMMU-PRO Vision
e.g.,
python3 benchmarks/mmmu_bench.py \
--model mistralai/Pixtral-12B-2409 \
--tokenizer-mode mistral \
--num-prompts 1000 \
--image-hit-rate 0.5

python3 benchmarks/mmmu_bench.py \
--model allenai/Molmo-72B-0924 \
--tensor-parallel-size 4 \
--trust-remote-code \
--num-prompts 1000
"""
import argparse
import asyncio
import base64
import dataclasses
import io
import math
import random
import time
from itertools import chain

from datasets import load_dataset
from PIL import Image

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.utils import FlexibleArgumentParser


def sample_mmmu_pro_vision_requests(
dataset,
num_requests: int,
image_hit_rate: float,
):
sampled_requests = []
num_unique_images = max(int(num_requests * (1 - image_hit_rate)), 1)
print(
f"Total {num_requests} requests with {num_unique_images} unique images"
)
dataset = dataset.take(num_unique_images)

# The dataset with streaming=True fetches (downloads) 64 rows at a time.
print("Fetching data. This may take a while...")
for data in dataset:
if len(sampled_requests) == num_requests:
break

# MMMU-Pro vision direct prompt
# Ref: https://github.com/MMMU-Benchmark/MMMU/blob/6ce42f4d8f70c1841c67867152648974415b5cac/mmmu-pro/prompts.yaml#L5
prompt = (
"Answer with the option letter from the given choices directly. "
"The last line of your response should be of the following "
"format: 'Answer: $LETTER' (without quotes) where LETTER is one of "
"options.")

image: Image.Image = data["image"]
image = image.convert("RGB")
image_data = io.BytesIO()
image.save(image_data, format='JPEG')
image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8")
mm_content = {
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
},
}

messages = [{
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
mm_content,
],
}]
sampled_requests.append(messages)

n = math.ceil(num_requests / num_unique_images)
sampled_requests = list(
chain.from_iterable([x] * n for x in sampled_requests))[:num_requests]

return sampled_requests


def sample_hf_requests(
num_requests: int,
random_seed: int,
image_hit_rate: float,
):
dataset = load_dataset('MMMU/MMMU_Pro',
name='vision',
split="test",
streaming=True)
dataset = dataset.shuffle(seed=random_seed)
return sample_mmmu_pro_vision_requests(dataset, num_requests,
image_hit_rate)


def initialize_llm(engine_args):
print("Initializing LLM...")
return LLM(**dataclasses.asdict(engine_args))


async def main(args: argparse.Namespace):
print(args)
random.seed(args.seed)
engine_args = EngineArgs.from_cli_args(args)

sampling_params = SamplingParams(max_tokens=args.output_len, temperature=0)
chat_template = load_chat_template(args.chat_template)

# Concurrently initialize the LLM and sample data. Note that since
# both initialize_llm and sample_hf_requests are blocking, we need to
# use asyncio.to_thread to create async coroutines.
st = time.perf_counter()
sampling_task = asyncio.create_task(
asyncio.to_thread(sample_hf_requests, args.num_prompts, args.seed,
args.image_hit_rate))
llm_task = asyncio.create_task(
asyncio.to_thread(initialize_llm, engine_args))

sampled, llm = await asyncio.gather(sampling_task, llm_task)
print(f"Data sampling + LLM init time: {time.perf_counter() - st:.2f}s")

st = time.perf_counter()
outputs = llm.chat(sampled,
sampling_params=sampling_params,
chat_template=chat_template)
duration = time.perf_counter() - st

total_generated_tokens = sum(
len(output.outputs[0].token_ids) for output in outputs)

print(f"Request throughput: {args.num_prompts / duration:.2f} req/s")
print(f"Total generated tokens: {total_generated_tokens}")
print(
f"Token generation rate: {total_generated_tokens / duration:.2f} tok/s"
)


if __name__ == "__main__":
parser = FlexibleArgumentParser(description="Benchmark the throughput.")
parser.add_argument("--output-len",
type=int,
default=128,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--num-prompts",
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--image-hit-rate",
type=float,
default=0.0,
help="Image hit rate between 0 and 1.")
parser.add_argument("--chat-template",
type=str,
default=None,
help="Set the chat template to use.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

asyncio.run(main(args))
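One detail worth spelling out from the script above: `--image-hit-rate` controls how many distinct images back the sampled requests. The script keeps `max(int(num_prompts * (1 - hit_rate)), 1)` unique images and repeats each one `ceil(num_prompts / num_unique_images)` times, so a higher hit rate means more repeated images and therefore more potential processor-cache hits. A quick check of that arithmetic with example values:

```python
import math

# Hypothetical benchmark settings for illustration.
num_prompts = 1000
image_hit_rate = 0.5

# 500 unique images; the other 500 requests reuse an already-seen image,
# which is exactly what the processor cache should turn into cache hits.
num_unique_images = max(int(num_prompts * (1 - image_hit_rate)), 1)
repeats_per_image = math.ceil(num_prompts / num_unique_images)

assert num_unique_images == 500
assert repeats_per_image == 2
```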
14 changes: 6 additions & 8 deletions tests/multimodal/test_processing.py
@@ -509,14 +509,12 @@ def _rand_video(
min_wh: int,
max_wh: int,
):
num_frames = rng.randint(min_frames, max_frames)
w, h = rng.randint(min_wh, max_wh, size=(2, ))

# Temporary fix. Qwen2-VL video processor fails on video of shape
# (b, 199, 178, 3) where b in (3, 5, 7)
w = (w // 32) * 32
h = (h // 32) * 32
num_frames = rng.randint(min_frames, max_frames)
num_frames = (num_frames // 2) * 2

w, h = rng.randint(min_wh, max_wh, size=(2, ))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)


@@ -527,7 +525,7 @@ def _rand_audio(
sr: int,
):
audio_len = rng.randint(min_len, max_len)
return rng.randint(0, 255, size=(audio_len, ), dtype=np.uint8), sr
return rng.rand(audio_len), sr


# yapf: disable
Expand All @@ -542,7 +540,7 @@ def _rand_audio(
("fixie-ai/ultravox-v0_3", {"audio"}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [10])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_cache_correctness(
@@ -588,7 +586,7 @@ def test_processing_cache_correctness(
"video":
partial(_rand_video,
rng,
min_frames=1,
min_frames=2,
max_frames=8,
min_wh=128,
max_wh=256),
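The hunk above only adjusts the random-input helpers and the `num_batches` parametrization; the body of `test_processing_cache_correctness` is not shown here. As a rough, assumed sketch of the kind of comparison such a test performs (the helper below and its `apply` signature are made up for illustration), it runs the same prompts through a processor with and without the cache and asserts identical outputs:

```python
def check_cache_correctness(baseline_processor, cached_processor, batches):
    """Hypothetical helper; the real test lives in tests/multimodal/test_processing.py."""
    for prompt, mm_data in batches:
        baseline = baseline_processor.apply(prompt, mm_data)
        cached = cached_processor.apply(prompt, mm_data)
        # The cached path must be a pure optimization: whatever fraction of
        # items is already in the cache, the outputs must be identical.
        assert baseline == cached
```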
32 changes: 18 additions & 14 deletions vllm/multimodal/inputs.py
@@ -280,23 +280,20 @@ def from_hf_inputs(
for key, config in config_by_key.items() if key in hf_inputs
}

if enable_sanity_checks:
batch_sizes = {k: len(v) for k, v in items_by_key.items()}
batch_size = next(iter(batch_sizes.values()), 0)
assert all(bs == batch_size for bs in batch_sizes.values()), dict(
batch_sizes=batch_sizes, items_by_key=items_by_key)

# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM
data = {k: hf_inputs[k] for k in items_by_key}

return MultiModalKwargs(data, items_by_key=items_by_key)
return MultiModalKwargs(data,
items_by_key=items_by_key,
enable_sanity_checks=enable_sanity_checks)

def __init__(
self,
data: Mapping[str, NestedTensors],
*,
items_by_key: Optional[Mapping[str, list[MultiModalFieldItem]]] = None,
enable_sanity_checks: bool = False,
) -> None:
if items_by_key is None:
items_by_key = {}
@@ -313,6 +310,17 @@ def __init__(

self._keys_by_modality = dict(keys_by_modality)

if enable_sanity_checks:
for modality, keys in keys_by_modality.items():
items_in_modality = {k: items_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
batch_size = next(iter(batch_sizes.values()), 0)
assert all(bs == batch_size
for bs in batch_sizes.values()), dict(
modality=modality,
batch_sizes=batch_sizes,
items_by_key=items_by_key)

@staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
"""
@@ -435,18 +443,14 @@ def from_items_by_modality(
for k, v in field.items():
items_by_key[k].append(v)

if enable_sanity_checks:
batch_sizes = {k: len(v) for k, v in items_by_key.items()}
batch_size = next(iter(batch_sizes.values()), 0)
assert all(bs == batch_size for bs in batch_sizes.values()), dict(
batch_sizes=batch_sizes, items_by_key=items_by_key)

data = {
k: items[0].field.reduce(items).data
for k, items in items_by_key.items()
}

return MultiModalKwargs(data, items_by_key=items_by_key)
return MultiModalKwargs(data,
items_by_key=items_by_key,
enable_sanity_checks=enable_sanity_checks)


MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
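The hunk above moves the batch-size sanity check from the two constructor classmethods into `MultiModalKwargs.__init__` and scopes it per modality instead of across all keys at once. The motivation appears to be that different modalities can legitimately carry different item counts within one request; only keys belonging to the same modality must agree. A small illustration with made-up keys and item placeholders:

```python
# Two image items and one audio item in a single request.
items_by_key = {
    "pixel_values": ["img_item_0", "img_item_1"],    # image modality
    "image_grid_thw": ["img_item_0", "img_item_1"],  # image modality
    "audio_features": ["audio_item_0"],              # audio modality
}
keys_by_modality = {
    "image": {"pixel_values", "image_grid_thw"},
    "audio": {"audio_features"},
}

# A single check across all keys (as before this change) would compare
# 2 == 2 == 1 and fail; the per-modality check only requires agreement
# among the keys of each modality.
for modality, keys in keys_by_modality.items():
    batch_sizes = {k: len(items_by_key[k]) for k in keys}
    assert len(set(batch_sizes.values())) <= 1, (modality, batch_sizes)
```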