diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index d09cbe5ca02e9..b07ac5baecda9 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -701,63 +701,64 @@ def input_processor_for_pixtral_hf(
     new_prompt = inputs.get("prompt")
     new_token_ids = inputs["prompt_token_ids"]
 
+    image_token = processor.image_token
+    image_break_token = processor.image_break_token
+    image_end_token = processor.image_end_token
+
     # Update new_prompt if present
     if new_prompt:
-        replace_strings = []
-        for image in image_data:
-            w, h = image.size
+        parts = new_prompt.split(image_token)
+        assert len(parts) - 1 == len(image_data)
+        new_parts = [parts[0]]  # Start with the part before any image tokens
+        for image, next_part in zip(image_data, parts[1:]):
+            w, h = image.size
 
             (num_width_tokens,
              num_height_tokens) = get_pixtral_hf_image_feature_size(
                  hf_config, image_width=w, image_height=h)
 
-            replace_tokens = [[processor.image_token] * num_width_tokens +
-                              [processor.image_break_token]
-                              ] * num_height_tokens
-            # Flatten list
-            replace_tokens = [
-                item for sublist in replace_tokens for item in sublist
+            replace_tokens = [image_token] * num_width_tokens + [
+                image_break_token
             ]
-            replace_tokens[-1] = processor.image_end_token
-            replace_str = "".join(replace_tokens)
-            replace_strings.append(replace_str)
-            new_prompt = new_prompt.replace(processor.image_token,
-                                            "<placeholder>", 1)
+            replace_tokens = replace_tokens * num_height_tokens
+            replace_tokens[-1] = image_end_token
 
-        while "<placeholder>" in new_prompt:
-            replace_str = replace_strings.pop(0)
-            new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
+            new_parts.append("".join(replace_tokens))
+            new_parts.append(next_part)
+
+        new_prompt = "".join(new_parts)
 
     # Update new_token_ids
-    image_token_id = 10
-    image_break_id = 12
-    image_end_id = 13
+    convert_tokens_to_ids = processor.tokenizer.convert_tokens_to_ids
+    image_token_id = convert_tokens_to_ids(image_token)
+    image_break_id = convert_tokens_to_ids(image_break_token)
+    image_end_id = convert_tokens_to_ids(image_end_token)
     placeholder_token_id = -999
+    # Find all image token indices at once
+    placeholder_indices = [
+        idx for idx, token_id in enumerate(new_token_ids)
+        if token_id == image_token_id
+    ]
+    assert len(placeholder_indices) == len(image_data)
     replace_tokens_list = []
-    for image in image_data:
-        w, h = image.size
+    for placeholder_idx, image in zip(placeholder_indices, image_data):
+        new_token_ids[placeholder_idx] = placeholder_token_id
 
-        num_width_tokens, num_height_tokens = get_pixtral_hf_image_feature_size(
-            hf_config, image_width=w, image_height=h)
+        w, h = image.size
+        (num_width_tokens,
+         num_height_tokens) = get_pixtral_hf_image_feature_size(hf_config,
+                                                                image_width=w,
+                                                                image_height=h)
 
-        replace_tokens = [[image_token_id] * num_width_tokens +
-                          [image_break_id]] * num_height_tokens
-        # Flatten list
-        replace_tokens = [
-            item for sublist in replace_tokens for item in sublist
-        ]
+        replace_tokens = [image_token_id] * num_width_tokens + [image_break_id]
+        replace_tokens = replace_tokens * num_height_tokens
         replace_tokens[-1] = image_end_id
         replace_tokens_list.append(replace_tokens)
-        # Replace image id with placeholder id
-        next_image_index = new_token_ids.index(image_token_id)
-        new_token_ids[next_image_index] = placeholder_token_id
-
-    while placeholder_token_id in new_token_ids:
-        replace_tokens = replace_tokens_list.pop(0)
-        next_image_index = new_token_ids.index(placeholder_token_id)
-        prefix = new_token_ids[:next_image_index]
-        postfix = new_token_ids[next_image_index + 1:]
-        new_token_ids = prefix + replace_tokens + postfix
+
+    # Backward iteration for replacement without affecting known indices
+    for placeholder_idx, replace_tokens in zip(reversed(placeholder_indices),
+                                               reversed(replace_tokens_list)):
+        new_token_ids[placeholder_idx:placeholder_idx + 1] = replace_tokens
 
     # NOTE: Create a defensive copy of the original inputs
     return token_inputs(prompt_token_ids=new_token_ids,
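
For reference, here is a minimal, self-contained sketch (not part of the patch) of the replacement scheme the new code implements: each image expands to a row of image tokens terminated by a break token, repeated once per row, with the grid's final token swapped for the end token; the expansions are then spliced into the token list from right to left so earlier placeholder indices stay valid. The helper names and the example token ids below are purely illustrative; ids 10/12/13 mirror the hard-coded constants the patch removes.

```python
# Illustrative sketch only; names and ids are hypothetical.

def build_replacement(num_width_tokens: int, num_height_tokens: int,
                      image_token_id: int, image_break_id: int,
                      image_end_id: int) -> list[int]:
    # One row = width image tokens followed by a break token;
    # the last token of the whole grid becomes the end token.
    row = [image_token_id] * num_width_tokens + [image_break_id]
    grid = row * num_height_tokens
    grid[-1] = image_end_id
    return grid


def splice_replacements(token_ids: list[int], image_token_id: int,
                        replacements: list[list[int]]) -> list[int]:
    token_ids = list(token_ids)
    indices = [i for i, t in enumerate(token_ids) if t == image_token_id]
    assert len(indices) == len(replacements)
    # Splice from right to left so earlier indices remain valid.
    for idx, repl in zip(reversed(indices), reversed(replacements)):
        token_ids[idx:idx + 1] = repl
    return token_ids


# Two images: a 2x2 grid and a 3x1 grid, spliced into a toy prompt.
tokens = [1, 10, 5, 10, 2]
repls = [build_replacement(2, 2, 10, 12, 13),
         build_replacement(3, 1, 10, 12, 13)]
print(splice_replacements(tokens, 10, repls))
# [1, 10, 10, 12, 10, 10, 13, 5, 10, 10, 10, 13, 2]
```

Splicing in reverse (or equivalently building the result in one pass) avoids the repeated `list.index` scans and prefix/postfix concatenations of the removed loop, which grew quadratically with the number of images.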