Skip to content

Commit

Permalink
Support different image sizes in prefill in VLMs
Browse files Browse the repository at this point in the history
When a batch contained images of different sizes during prefill, the
server would fail (see e.g. #2056). Images were processed separately and
then concatenated. However, this can fail for images with different sizes.

Fix this by preprocessing all images in the batch together, so that the
image processor can ensure that all image tensors have compatible sizes.
  • Loading branch information
danieldk committed Jun 13, 2024
1 parent 376a0b7 commit 7d3439f
Show file tree
Hide file tree
Showing 6 changed files with 291 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 8,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 2502,
"logprob": -1.734375,
"special": false,
"text": "image"
},
{
"id": 2196,
"logprob": -0.5756836,
"special": false,
"text": " result"
},
{
"id": 604,
"logprob": -0.007843018,
"special": false,
"text": " for"
},
{
"id": 12254,
"logprob": -1.7167969,
"special": false,
"text": " chicken"
},
{
"id": 611,
"logprob": -0.17053223,
"special": false,
"text": " on"
},
{
"id": 573,
"logprob": -0.7626953,
"special": false,
"text": " the"
},
{
"id": 8318,
"logprob": -0.02709961,
"special": false,
"text": " beach"
},
{
"id": 1,
"logprob": -0.20739746,
"special": true,
"text": "<eos>"
}
],
"top_tokens": null
},
"generated_text": "image result for chicken on the beach"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 20,
"prefill": [],
"seed": null,
"tokens": [
{
"id": 415,
"logprob": -0.04421997,
"special": false,
"text": " The"
},
{
"id": 12072,
"logprob": -0.13500977,
"special": false,
"text": " cow"
},
{
"id": 349,
"logprob": -0.06750488,
"special": false,
"text": " is"
},
{
"id": 6328,
"logprob": -0.6352539,
"special": false,
"text": " standing"
},
{
"id": 356,
"logprob": -0.16186523,
"special": false,
"text": " on"
},
{
"id": 272,
"logprob": -0.5078125,
"special": false,
"text": " the"
},
{
"id": 10305,
"logprob": -0.017913818,
"special": false,
"text": " beach"
},
{
"id": 304,
"logprob": -1.5205078,
"special": false,
"text": " and"
},
{
"id": 272,
"logprob": -0.029174805,
"special": false,
"text": " the"
},
{
"id": 13088,
"logprob": -0.003479004,
"special": false,
"text": " chicken"
},
{
"id": 349,
"logprob": -0.0035095215,
"special": false,
"text": " is"
},
{
"id": 6398,
"logprob": -0.3088379,
"special": false,
"text": " sitting"
},
{
"id": 356,
"logprob": -0.027755737,
"special": false,
"text": " on"
},
{
"id": 264,
"logprob": -0.31884766,
"special": false,
"text": " a"
},
{
"id": 17972,
"logprob": -0.047943115,
"special": false,
"text": " pile"
},
{
"id": 302,
"logprob": -0.0002925396,
"special": false,
"text": " of"
},
{
"id": 2445,
"logprob": -0.02935791,
"special": false,
"text": " money"
},
{
"id": 28723,
"logprob": -0.031219482,
"special": false,
"text": "."
},
{
"id": 32002,
"logprob": -0.00034475327,
"special": true,
"text": "<end_of_utterance>"
},
{
"id": 2,
"logprob": -1.1920929e-07,
"special": true,
"text": "</s>"
}
],
"top_tokens": null
},
"generated_text": " The cow is standing on the beach and the chicken is sitting on a pile of money."
}
23 changes: 23 additions & 0 deletions integration-tests/models/test_flash_pali_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ async def flash_pali_gemma(flash_pali_gemma_handle):
return flash_pali_gemma_handle.client


def get_chicken() -> str:
    """Return the chicken test image as a base64-encoded PNG data URI.

    The image is read from ``integration-tests/images/chicken_on_money.png``,
    a relative path that is resolved against the current working directory.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    # b64encode returns bytes; decode so the payload can be embedded in a str.
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
Expand All @@ -37,3 +43,20 @@ async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):

assert response.generated_text == "beach"
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma_two_images(flash_pali_gemma, response_snapshot):
    """Generate a caption for a prompt that embeds two images.

    Both images are inlined into the prompt as markdown image links holding
    base64 data URIs. The generated text is checked against a fixed string
    and the full response against the stored snapshot.
    """
    chicken = get_chicken()
    cow_beach = get_cow_beach()
    response = await flash_pali_gemma.generate(
        f"caption![]({chicken})![]({cow_beach})\n",
        max_new_tokens=20,
    )
    # Is PaliGemma not able to handle two separate images? At least we
    # get output showing that both images are used.
    assert (
        response.generated_text == "image result for chicken on the beach"
    ), f"{repr(response.generated_text)}"
    assert response == response_snapshot
21 changes: 21 additions & 0 deletions integration-tests/models/test_idefics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ def get_chicken():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


def get_cow_beach() -> str:
    """Return the cow-on-beach test image as a base64-encoded PNG data URI.

    The image is read from ``integration-tests/images/cow_beach.png``, a
    relative path that is resolved against the current working directory.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    # b64encode returns bytes; decode so the payload can be embedded in a str.
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
async def test_idefics(idefics, response_snapshot):
chicken = get_chicken()
Expand All @@ -39,6 +45,21 @@ async def test_idefics(idefics, response_snapshot):
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_idefics_two_images(idefics, response_snapshot):
    """Ask a question about a prompt that embeds two images.

    Both images are inlined into the prompt as markdown image links holding
    base64 data URIs. The generated text is checked against a fixed string
    and the full response against the stored snapshot.
    """
    chicken = get_chicken()
    cow_beach = get_cow_beach()
    response = await idefics.generate(
        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
        max_new_tokens=20,
    )
    # Attach the actual text to the failure message for easier debugging.
    assert (
        response.generated_text == " The cow and chicken are on a beach."
    ), f"{repr(response.generated_text)}"
    assert response == response_snapshot


@pytest.mark.asyncio
async def test_idefics_load(idefics, generate_load, response_snapshot):
chicken = get_chicken()
Expand Down
23 changes: 23 additions & 0 deletions integration-tests/models/test_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ def get_chicken():
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


def get_cow_beach() -> str:
    """Return the cow-on-beach test image as a base64-encoded PNG data URI.

    The image is read from ``integration-tests/images/cow_beach.png``, a
    relative path that is resolved against the current working directory.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    # b64encode returns bytes; decode so the payload can be embedded in a str.
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.fixture(scope="module")
def flash_idefics2_next_handle(launcher):
with launcher(
Expand Down Expand Up @@ -38,6 +44,23 @@ async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot
assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_two_images(flash_idefics2_next, response_snapshot):
    """Ask a question about a prompt that embeds two images.

    Both images are inlined into the prompt as markdown image links holding
    base64 data URIs. The test checks the generated text, the number of
    generated tokens, and the full response against the stored snapshot.
    """
    chicken = get_chicken()
    cow_beach = get_cow_beach()
    response = await flash_idefics2_next.generate(
        f"User:![]({chicken})![]({cow_beach})Where are the cow and chicken?<end_of_utterance> \nAssistant:",
        max_new_tokens=20,
    )
    # Attach the actual text to the failure message for easier debugging.
    assert (
        response.generated_text
        == " The cow is standing on the beach and the chicken is sitting on a pile of money."
    ), f"{repr(response.generated_text)}"
    # The expected count equals max_new_tokens passed to generate() above.
    assert response.details.generated_tokens == 20
    assert response == response_snapshot


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
Expand Down
57 changes: 30 additions & 27 deletions server/text_generation_server/models/vlm_causal_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def image_text_replacement(image_input, config, image_id) -> str:
num_features = get_number_of_features(height, width, config)
from loguru import logger

logger.info(f"Found {num_features} in image of resolution {height}x{width}")
logger.info(
f"Found {num_features} features in image of resolution {height}x{width}"
)
return "<image>" * num_features

elif config.model_type == "paligemma":
Expand Down Expand Up @@ -133,23 +135,41 @@ def filter(self, request_ids: List[int]):
def batch_tokenized_inputs(
cls, requests: Iterable[generate_pb2.Request], tokenizer, processor, config
):
# Process images first. We need all of them so that the processor
# can make the image splits the same size. And we need the final
# sizes to insert correct number of image tokens.
images = []
for r in requests:
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
pass
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
if config.model_type == "llava_next":
images.append(image)
else:
images.append([image])
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")

if images:
image_inputs = processor.image_processor(images, return_tensors="pt")
else:
image_inputs = None

batch_inputs = []
image_inputs = []
max_truncation = 0
image_id = 0
for r in requests:
full_text = ""
image_id = 0
for chunk in r.input_chunks.chunks:
chunk_type = chunk.WhichOneof("chunk")
if chunk_type == "text":
full_text += chunk.text
elif chunk_type == "image":
image = Image.open(BytesIO(chunk.image.data))
image_input = processor.image_processor(image, return_tensors="pt")
full_text += image_text_replacement(image_input, config, image_id)
image_inputs.append(image_input)
else:
raise RuntimeError(f"Invalid chunk type {chunk_type}")
full_text += image_text_replacement(image_inputs, config, image_id)
image_id += 1

batch_inputs.append(full_text)
max_truncation = max(max_truncation, r.truncate)
Expand All @@ -160,24 +180,7 @@ def batch_tokenized_inputs(
max_length=max_truncation,
add_special_tokens=not config.model_type == "paligemma",
)["input_ids"]
if image_inputs:
image_input = image_inputs[0]
new_image_inputs = {
"pixel_values": torch.cat(
[img["pixel_values"] for img in image_inputs], dim=0
),
}
if "pixel_attention_mask" in image_input:
new_image_inputs["pixel_attention_mask"] = torch.cat(
[img["pixel_attention_mask"] for img in image_inputs], dim=0
)
if "image_sizes" in image_input:
new_image_inputs["image_sizes"] = torch.cat(
[img["image_sizes"] for img in image_inputs], dim=0
)
image_inputs = new_image_inputs
else:
image_inputs = None

return batch_tokenized_inputs, image_inputs

@classmethod
Expand Down

0 comments on commit 7d3439f

Please sign in to comment.