Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model] Support Pixtral models in the HF Transformers format #9036

Merged
15 commits merged on
Oct 18, 2024
Previous commit
Next commit
Fix new_token_ids
  • Loading branch information
mgoin committed Oct 18, 2024
commit a8c0f3540d4f909b216f745470b7fa4926d4fff2
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,8 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"silu":
lambda: nn.SiLU(),
"quick_gelu":
lambda: QuickGELU(),
})
Expand Down
68 changes: 52 additions & 16 deletions vllm/model_executor/models/pixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,6 @@ def input_processor_for_pixtral_hf(
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs

tokenizer = cached_get_tokenizer(model_config.tokenizer)
processor = cached_get_processor(model_config.model)

image_data = multi_modal_data["image"]
Expand All @@ -699,31 +698,66 @@ def input_processor_for_pixtral_hf(
elif not is_list_of(image_data, Image.Image):
raise TypeError(f"Invalid image type: {type(image_data)}")

replace_strings = []
new_prompt = inputs.get("prompt")
new_token_ids = inputs["prompt_token_ids"]

# Update new_prompt if present
if new_prompt:
replace_strings = []
for image in image_data:
w, h = image.size

(num_width_tokens,
num_height_tokens) = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)

replace_tokens = [[processor.image_token] * num_width_tokens +
[processor.image_break_token]
] * num_height_tokens
# Flatten list
replace_tokens = [
item for sublist in replace_tokens for item in sublist
]
replace_tokens[-1] = processor.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
new_prompt = new_prompt.replace(processor.image_token,
"<placeholder>", 1)

while "<placeholder>" in new_prompt:
replace_str = replace_strings.pop(0)
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)

# Update new_token_ids
image_token_id = 10
image_break_id = 12
image_end_id = 13
placeholder_token_id = -999
replace_tokens_list = []
for image in image_data:
w, h = image.size

num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
hf_config, image_width=w, image_height=h)

replace_tokens = [[processor.image_token] * num_width_tokens +
[processor.image_break_token]] * num_height_tokens
replace_tokens = [[image_token_id] * num_width_tokens +
[image_break_id]] * num_height_tokens
# Flatten list
replace_tokens = [
item for sublist in replace_tokens for item in sublist
]
replace_tokens[-1] = processor.image_end_token
replace_str = "".join(replace_tokens)
replace_strings.append(replace_str)
new_prompt = new_prompt.replace(processor.image_token, "<placeholder>",
1)

while "<placeholder>" in new_prompt:
replace_str = replace_strings.pop(0)
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)

new_token_ids = tokenizer(new_prompt)["input_ids"]
replace_tokens[-1] = image_end_id
replace_tokens_list.append(replace_tokens)
# Replace image id with placeholder id
next_image_index = new_token_ids.index(image_token_id)
new_token_ids[next_image_index] = placeholder_token_id

while placeholder_token_id in new_token_ids:
replace_tokens = replace_tokens_list.pop(0)
next_image_index = new_token_ids.index(placeholder_token_id)
prefix = new_token_ids[:next_image_index]
postfix = new_token_ids[next_image_index + 1:]
new_token_ids = prefix + replace_tokens + postfix

# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
Expand Down Expand Up @@ -958,7 +992,9 @@ def forward(
# pass images through initial convolution independently
dtype = next(self.parameters()).dtype
patch_embeds_list = [
self.patch_conv(img.unsqueeze(0).to(dtype)) for img in pixel_values
self.patch_conv(
img.reshape(-1, img.shape[-3], img.shape[-2],
img.shape[-1]).to(dtype)) for img in pixel_values
]

# flatten to a single sequence
Expand Down