Skip to content

Commit

Permalink
[Model][Pixtral] Optimizations for input_processor_for_pixtral_hf (vl…
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin authored Oct 19, 2024
1 parent 263d8ee commit 8e3e7f2
Showing 1 changed file with 41 additions and 40 deletions.
81 changes: 41 additions & 40 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,63 +701,64 @@ def input_processor_for_pixtral_hf(
new_prompt = inputs.get("prompt")
new_token_ids = inputs["prompt_token_ids"]

image_token = processor.image_token
image_break_token = processor.image_break_token
image_end_token = processor.image_end_token

# Update new_prompt if present
if new_prompt:
replace_strings = []
for image in image_data:
w, h = image.size
parts = new_prompt.split(image_token)
assert len(parts) - 1 == len(image_data)
new_parts = [parts[0]] # Start with the part before any image tokens

for image, next_part in zip(image_data, parts[1:]):
w, h = image.size
(num_width_tokens,
num_height_tokens) = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)

replace_tokens = [[processor.image_token] * num_width_tokens +
[processor.image_break_token]
] * num_height_tokens
# Flatten list
replace_tokens = [
item for sublist in replace_tokens for item in sublist
replace_tokens = [image_token] * num_width_tokens + [
image_break_token
]
replace_tokens[-1] = processor.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
new_prompt = new_prompt.replace(processor.image_token,
"<placeholder>", 1)
replace_tokens = replace_tokens * num_height_tokens
replace_tokens[-1] = image_end_token

while "<placeholder>" in new_prompt:
replace_str = replace_strings.pop(0)
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
new_parts.append("".join(replace_tokens))
new_parts.append(next_part)

new_prompt = "".join(new_parts)

# Update new_token_ids
image_token_id = 10
image_break_id = 12
image_end_id = 13
convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
image_token_id = convert_tokens_to_ids(image_token)
image_break_id = convert_tokens_to_ids(image_break_token)
image_end_id = convert_tokens_to_ids(image_end_token)
placeholder_token_id = -999
# Find all image token indices at once
placeholder_indices = [
idx for idx, token_id in enumerate(new_token_ids)
if token_id == image_token_id
]
assert len(placeholder_indices) == len(image_data)
replace_tokens_list = []
for image in image_data:
w, h = image.size
for placeholder_idx, image in zip(placeholder_indices, image_data):
new_token_ids[placeholder_idx] = placeholder_token_id

num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)
w, h = image.size
(num_width_tokens,
num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
image_width=w,
image_height=h)

replace_tokens = [[image_token_id] * num_width_tokens +
[image_break_id]] * num_height_tokens
# Flatten list
replace_tokens = [
item for sublist in replace_tokens for item in sublist
]
replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
replace_tokens = replace_tokens * num_height_tokens
replace_tokens[-1] = image_end_id
replace_tokens_list.append(replace_tokens)
# Replace image id with placeholder id
next_image_index = new_token_ids.index(image_token_id)
new_token_ids[next_image_index] = placeholder_token_id

while placeholder_token_id in new_token_ids:
replace_tokens = replace_tokens_list.pop(0)
next_image_index = new_token_ids.index(placeholder_token_id)
prefix = new_token_ids[:next_image_index]
postfix = new_token_ids[next_image_index + 1:]
new_token_ids = prefix + replace_tokens + postfix

# Backward iteration for replacement without affecting known indices
for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
reversed(replace_tokens_list)):
new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
Expand Down

0 comments on commit 8e3e7f2

Please sign in to comment.