
[Bugfix][VLM] Fix Fuyu batching inference with max_num_seqs>1 #8892

Merged
merged 2 commits on Sep 27, 2024
6 changes: 2 additions & 4 deletions tests/models/decoder_only/vision_language/test_fuyu.py
@@ -65,8 +65,8 @@ def run_test(

     # max_model_len should be greater than image_feature_size
     with vllm_runner(model,
-                     max_model_len=2560,
-                     max_num_seqs=1,
+                     max_model_len=2048,
+                     max_num_seqs=2,
                      dtype=dtype,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
@@ -80,8 +80,6 @@ def run_test(
     ]

     with hf_runner(model, dtype=dtype) as hf_model:
-        hf_model.model.get_output_embeddings = lambda: \
-            hf_model.model.language_model.get_output_embeddings()
         eos_token_id = hf_model.processor.tokenizer.eos_token_id
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
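For context on what the widened max_num_seqs setting exercises, here is a minimal sketch of a batched two-image request against Fuyu using vLLM's offline LLM API. The prompt strings, image paths, and the exact multi_modal_data layout are assumptions for illustration, not details taken from this PR:

```python
from PIL import Image
from vllm import LLM, SamplingParams

# Allow two sequences to be batched together, matching the updated test.
llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2)
sampling = SamplingParams(temperature=0.0, max_tokens=64)

requests = [
    {
        "prompt": "What is shown in this image?\n",             # hypothetical prompt
        "multi_modal_data": {"image": Image.open("cat.png")},   # placeholder path
    },
    {
        "prompt": "Describe the chart.\n",                       # hypothetical prompt
        "multi_modal_data": {"image": Image.open("chart.png")},  # placeholder path
    },
]

# Before this fix, running more than one Fuyu request per batch could fail;
# with max_num_seqs=2 both requests can be scheduled together.
for output in llm.generate(requests, sampling):
    print(output.outputs[0].text)
```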
51 changes: 35 additions & 16 deletions vllm/model_executor/models/fuyu.py
@@ -42,7 +42,7 @@
                            SequenceData)

 from .interfaces import SupportsMultiModal
-from .utils import merge_multimodal_embeddings
+from .utils import flatten_bn, merge_multimodal_embeddings

 # Cannot find the following 2 numbers from hf config.
 _IMAGE_TOKEN_ID = 71011
@@ -165,7 +165,7 @@ def input_processor_for_fuyu(ctx: InputContext, llm_inputs: LLMInputs):
             model_config.model)

         model_image_input = _fuyu_image_preprocess(image_processor, image_data)
-        image_patches = torch.stack([
+        image_patches = torch.cat([
             image_patch[0]
             for image_patch in model_image_input["image_patches"]
         ])
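The torch.stack to torch.cat switch changes the resulting shape: stack inserts a new leading dimension, while cat joins the per-image patch tensors along their existing first dimension, so the result stays two-dimensional (num_patches, patch_dim). A small standalone illustration (the shapes here are made up):

```python
import torch

# One image's flattened patches, shaped (num_patches, patch_dim).
patches = [torch.randn(12, 2700)]

stacked = torch.stack(patches)  # shape (1, 12, 2700): a new leading dim is added
merged = torch.cat(patches)     # shape (12, 2700): joined along the existing dim

print(stacked.shape, merged.shape)
```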
@@ -210,7 +210,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object):
         ])

     # image has been processed with prompt in input processor
-    return MultiModalInputs({"image_patches": data})
+    return MultiModalInputs({"pixel_values": data})


 @MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@@ -242,23 +242,42 @@ def __init__(self,
                                                    cache_config=cache_config,
                                                    quant_config=quant_config)

+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+
+        h = w = self.config.patch_size
+        num_channels = self.config.num_channels
+        expected_dims = num_channels * h * w
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = d.size(-1)
+
+            if actual_dims != expected_dims:
+                expected_expr = str(expected_dims)
+                raise ValueError(
+                    "The expected shape of pixel values per image per batch "
+                    f" per patch is {expected_expr}. "
+                    f"You supplied {tuple(d.shape)}.")
+
+        for d in data:
+            _validate_shape(d)
+
+        return data.to(self.vision_embed_tokens.weight.dtype)
+
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
-        image_patches = kwargs.pop("image_patches", None)
+        pixel_values = kwargs.pop("pixel_values", None)

-        if isinstance(image_patches, torch.Tensor):
-            # Remove the N dimension until multiple images are supported.
-            image_patches = image_patches.squeeze(1)
+        if pixel_values is not None:
+            if not isinstance(pixel_values, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of image patches. "
+                                 f"Got type: {type(pixel_values)}")

+            return FuyuImagePixelInputs(
+                type="pixel_values",
+                data=self._validate_pixel_values(
+                    flatten_bn(pixel_values, concat=True)),
+            )
+
-            expected_feature_size = self.image_feature_size
-            if image_patches.size(-1) != expected_feature_size:
-                raise ValueError(
-                    f"Expected image patches to have the last dimension of "
-                    f"{expected_feature_size}, got {image_patches.size(-1)}")
-            image_patches = image_patches.to(
-                self.vision_embed_tokens.weight.dtype)
-            return FuyuImagePixelInputs(type="pixel_values",
-                                        data=image_patches)
         return None

     def _process_image_input(
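For reference, a rough sketch of what flatten_bn(..., concat=True) is expected to do with a batch of per-image patch tensors. This is an illustration of the idea only, not the actual helper in vllm/model_executor/models/utils.py, and the function name is suffixed to make that explicit:

```python
from typing import List, Union

import torch


def flatten_bn_sketch(
    x: Union[torch.Tensor, List[torch.Tensor]],
    concat: bool = False,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Collapse the batch dimension of batched multimodal inputs.

    With concat=True, per-sequence tensors (possibly with different
    numbers of patches) are concatenated along dim 0 into a single
    (total_patches, patch_dim) tensor.
    """
    if isinstance(x, torch.Tensor):
        # (batch, num_patches, patch_dim) -> (batch * num_patches, patch_dim)
        return x.flatten(0, 1)
    if concat:
        return torch.cat(x)
    return x


# Two sequences with different patch counts collapse into one flat tensor,
# which a per-patch validation pass can then check row by row.
batched = [torch.randn(12, 2700), torch.randn(20, 2700)]
print(flatten_bn_sketch(batched, concat=True).shape)  # torch.Size([32, 2700])
```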