Skip to content

[Model][VLM] Add Qwen2-VL model support #7905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Sep 11, 2024

Conversation

fyabc
Copy link
Contributor

@fyabc fyabc commented Aug 27, 2024

This PR adding support for Qwen2-VL model.

FIX #8139
FIX #8281

Requirements

  • This PR requires transformers with this PR merged and this bugfix PR merged (You can install it via pip install git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830).
  • NOTE: Current latest transformers version have a bug, so you should install a develop version as above now.
  • For transformers>=4.45, please install vLLM from source.
  • For transformers>=4.45, please install vllm>=0.6.3.

Optional Requirements

  • When constructing LLM inputs, we recommend using our helper package qwen-vl-utils to preprocess multimodal content correctly (qwen-vl-utils is not a part of this PR).

Example Usage

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM, SamplingParams
from qwen_vl_utils import process_vision_info

MODEL_PATH = 'Qwen/Qwen2-VL-7B-Instruct'
IMAGE_PATH = '/path/to/image.jpg'
VIDEO_PATH = '/path/to/video.mp4'

llm = LLM(
    model=MODEL_PATH,
    limit_mm_per_prompt={'image': 10, 'video': 10},
)

sampling_params = SamplingParams(
    temperature=0.1, top_p=0.001, repetition_penalty=1.05, max_tokens=256,
    stop_token_ids=[],
)

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': [
        {
            'type': 'image',
            'image': IMAGE_PATH,

            # min_pixels & max_pixels are optional
            'max_pixels': 12845056,
        },

        # You can also pass one or more videos:
        # {
        #     'type': 'video',
        #     'video': VIDEO_PATH,
        # }

        {
            'type': 'text',
            'text': 'What does this diagram illustrate?',
        },
    ]},
]

processor = AutoProcessor.from_pretrained(MODEL_PATH)
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True,
)
image_inputs, video_inputs = process_vision_info(messages)

mm_data = {}
if image_inputs is not None:
    mm_data['image'] = image_inputs
if video_inputs is not None:
    mm_data['video'] = video_inputs

llm_inputs = {
    'prompt': prompt,
    'multi_modal_data': mm_data,
}

outputs = llm.generate([llm_inputs], sampling_params=sampling_params)
generated_text = outputs[0].outputs[0].text

print(generated_text)

Notes

Here are some important notes about this PR:

  1. Qwen2-VL uses rotary embedding with multimodal sections (mrope) (see vllm/model_executor/layers/rotary_embedding.py for more details). This rotary embedding requires the input positions to be a tensor of shape (3, seq_len) (instead of (seq_len,) in common case).

    1. To support this feature, we add a new _mrope_position_delta (with type Optional[int]) attribute into vllm.sequence.SequenceData (this attribute is used to compute mrope_input_positions in each decoding step). (If reviewers have a better solution, please comment in this PR)
    2. We also change model_runner.py to compute the mrope_input_positions when the model uses mrope. Other model runners should also follow this logic, I think this can be done in another PR (I will add this part if reviewers thinks it needs to be implemented in this PR).
  2. Qwen2-VL uses flash-attn==2.6.1 (instead of vllm-flash-attn==2.6.1) to compute vision attention (see the commented line 36 in vllm/model_executor/models/qwen2_vl.py). Current vllm-flash-attn version will output NaN logits value, and I am still debugging this bug.

    1. UPDATE 2024.09.06: Add xformers backend as a fallback implementation of Qwen2VisionAttention, so there is no need to add flash-attn into project requirements file.
  3. Qwen2-VL supports both image and video inputs. To support this feature, we add a video multimodal plugin (see vllm/multimodal/video.py for more details).

  4. OpenAI-compatible server

    1. Currently, vllm.entrypoints.openai.api_server uses a model-independent multimodal data fetcher (e.g. vllm.multimodal.utils.async_get_and_parse_image), so vision smart resizing logic in qwen-vl-utils cannot be applied now. I think its good to create another PR to fix it later.
  5. Multiple modalities support details

    Since Qwen2-VL support two modalities (images and videos), we should handle some special cases as below:

    # 1. A batch with two samples, sample 1 contains images, sample 2 contains videos
    llm.generate([
        {
            "prompt": "XXX",
            "multi_modal_data": {
                "image": ...
            }
        },
        {
            "prompt": "XXX",
            "multi_modal_data": {
                "video": ...
            }
        }
    ])
    
    # 2. A single sample with both images and videos
    llm.generate([
        {
            "prompt": "XXX",
            "multi_modal_data": {
                "image": ...,
                "video": ...
            }
        }
    ])

    So I remove the key same check in vllm.multimodal.base.MultiModalInputs.batch() method, since different samples may returns different modality keys.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@DarkLight1337
Copy link
Member

DarkLight1337 commented Aug 29, 2024

Thanks for implementing this (and sorry for the delayed response)! Since this PR not only introduces a new modality (video) but also involves the first model to accept multiple modalities (excluding text), I would like to merge #7559 first to verify that vLLM can handle video inputs properly.

In the meantime, can you fix the CI failures?

@fyabc
Copy link
Contributor Author

fyabc commented Aug 29, 2024

Thanks for implementing this (and sorry for the delayed response)! Since this PR not only introduces a new modality (video) but also involves the first model to accept multiple modalities (excluding text), I would like to merge #7559 first to verify that vLLM can handle video inputs properly.

In the meantime, can you fix the CI failures?

image
Hi @DarkLight1337 , these mypy errors seems not belongs to this PR, should I also fix them?

Copy link
Member

@ywang96 ywang96 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fyabc Thank you for contributing to vLLM! I took a brief took and left a first round of review. Please take a look.

As @DarkLight1337 mentioned, we might want to wait for #7559 to be merged first because as we're going to have a model that supports a mix of modalities, we want to be careful with API changes.

Comment on lines 626 to 658
# special processing for mrope position deltas.
if self.runner.model_is_mrope:
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
assert image_grid_thw is not None or video_grid_thw is not None, \
"mrope embedding type requires multi-modal input mapper returns 'image_grid_thw' or 'video_grid_thw'."

hf_config = self.runner.model_config.hf_config

from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding

inter_data.mrope_input_positions = [None] * inter_data.n_seqs
for seq_idx in range(inter_data.n_seqs):
seq_data = seq_group_metadata.seq_data[
inter_data.seq_ids[seq_idx]]
token_ids = seq_data.get_token_ids()

mrope_input_positions, mrope_position_delta = MRotaryEmbedding.get_input_positions(
token_ids,
image_grid_thw=image_grid_thw,
video_grid_thw=video_grid_thw,
image_token_id=hf_config.image_token_id,
video_token_id=hf_config.video_token_id,
vision_start_token_id=hf_config.vision_start_token_id,
vision_end_token_id=hf_config.vision_end_token_id,
spatial_merge_size=hf_config.vision_config.
spatial_merge_size,
context_len=inter_data.context_lens[seq_idx],
)

seq_data.mrope_position_delta = mrope_position_delta
inter_data.mrope_input_positions[
seq_idx] = mrope_input_positions
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm okay with us doing this at the model runner level and I'm honestly sure if there's a better place to apply mrope. What's your thought on this? @WoosukKwon

@DarkLight1337
Copy link
Member

image Hi @DarkLight1337 , these mypy errors seems not belongs to this PR, should I also fix them?

Can you merge from main first? It fixes some of the mypy errors which might apply here.

@fyabc
Copy link
Contributor Author

fyabc commented Aug 29, 2024

Hi @DarkLight1337 @ywang96 , I have updated this PR based on your review comments, please check it again.
I also add some notes about multiple modalities in the PR overview.

@DragonFive
Copy link

@fyabc Hi, can this patch support mutiple images in one prompt like follows:

Compute the value of the expression in the image below <image_1>\nby using the emoji equations in the following images <image_2> <image_3> <image_4> <image_5> Only answer specific numerical values.

@fyabc
Copy link
Contributor Author

fyabc commented Aug 30, 2024

@fyabc Hi, can this patch support mutiple images in one prompt like follows:

Compute the value of the expression in the image below <image_1>\nby using the emoji equations in the following images <image_2> <image_3> <image_4> <image_5> Only answer specific numerical values.

Hi @DragonFive , you can pass multiple images into a single prompt like this:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "file:///path/to/image1.jpg"},
            {"type": "image", "image": "file:///path/to/image2.jpg"},
            {"type": "text", "text": "Identify the similarities between these images."},
        ],
    }
]

See "Multi image inference" section of our README for more details.

@5Elza5
Copy link

5Elza5 commented Sep 17, 2024

Ahhhh I could have run vLLM with Qwen2-VL-7B as:

CUDA_VISIBLE_DEVICES=3 python -m vllm.entrypoints.openai.api_server --limit-mm-per-prompt image=30 --host 0.0.0.0 --port 9999 --served-model-name EraX-VL-V1 --model ./EraX-VL-7B

I used the code from Qwen2 git (https://github.com/QwenLM/Qwen2-VL)

import cv2
import matplotlib.pyplot as plt
from PIL import Image

import uuid, base64

# Prepare base64 image
test_image1 = './samples/bill-1.png'

with open(test_image1, "rb") as f:
    encoded_image = base64.b64encode(f.read())

encoded_image_text = encoded_image.decode('utf-8')
base64_qwen = f"data:image;base64,{encoded_image_text}"

# Run
from openai import OpenAI

# Set OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:9999/v1"

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

prompt = "What is the content of the image?"

chat_response = client.chat.completions.create(
    model="EraX-VL-V1",
    temperature=0.2,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": base64_qwen,
                },
                {
                    "type": "text", 
                    "text": prompt
                },
            ],
        },
    ],
)

ERROR:

   1038         err.response.read()
   1040     log.debug("Re-raising status error")
-> 1041     raise self._make_status_error_from_response(err.response) from None
   1043 return self._process_response(
   1044     cast_to=cast_to,
   1045     options=options,
   (...)
   1049     retries_taken=options.get_max_retries(self.max_retries) - retries,
   1050 )

BadRequestError: Error code: 400 - {'object': 'error', 'message': 'Unknown part type: image', 'type': 'BadRequestError', 'param': None, 'code': 400}

Any hint please.

Thanks, Steve

try this:

"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{base64_image}"
from here:
https://platform.openai.com/docs/guides/vision/uploading-base-64-encoded-images

@yuanjietu
Copy link

Hi, I am running this and got the same error as #8281. Could someone help me with this? Thank you!

Unrecognized keys in rope_scaling for 'rope_type'='default': {'mrope_section'}

AssertionError Traceback (most recent call last)
/tmp/ipykernel_32014/2600037648.py in
9 del config.rope_scaling['mrope_section']
10
---> 11 llm = LLM(
12 model=MODEL_PATH,
13 limit_mm_per_prompt={'image': 10, 'video': 10},

~/.local/lib/python3.10/site-packages/vllm/entrypoints/llm.py in init(self, model, tokenizer, tokenizer_mode, skip_tokenizer_init, trust_remote_code, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, cpu_offload_gb, enforce_eager, max_context_len_to_capture, max_seq_len_to_capture, disable_custom_all_reduce, disable_async_output_proc, **kwargs)
176 **kwargs,
177 )
--> 178 self.llm_engine = LLMEngine.from_engine_args(
179 engine_args, usage_context=UsageContext.LLM_CLASS)
180 self.request_counter = Counter()

~/.local/lib/python3.10/site-packages/vllm/engine/llm_engine.py in from_engine_args(cls, engine_args, usage_context, stat_loggers)
545 """Creates an LLM engine from the engine arguments."""
546 # Create the engine configs.
--> 547 engine_config = engine_args.create_engine_config()
548 executor_class = cls._get_executor_cls(engine_config)
549 # Create the LLM engine.

~/.local/lib/python3.10/site-packages/vllm/engine/arg_utils.py in create_engine_config(self)
842
843 device_config = DeviceConfig(device=self.device)
--> 844 model_config = self.create_model_config()
845
846 if model_config.is_multimodal_model:

~/.local/lib/python3.10/site-packages/vllm/engine/arg_utils.py in create_model_config(self)
780
781 def create_model_config(self) -> ModelConfig:
--> 782 return ModelConfig(
783 model=self.model,
784 tokenizer=self.tokenizer,

~/.local/lib/python3.10/site-packages/vllm/config.py in init(self, model, tokenizer, tokenizer_mode, trust_remote_code, dtype, seed, revision, code_revision, rope_scaling, rope_theta, tokenizer_revision, max_model_len, spec_target_max_model_len, quantization, quantization_param_path, enforce_eager, max_context_len_to_capture, max_seq_len_to_capture, max_logprobs, disable_sliding_window, skip_tokenizer_init, served_model_name, limit_mm_per_prompt, use_async_output_proc, override_neuron_config, config_format)
225 self.disable_sliding_window = True
226
--> 227 self.max_model_len = _get_and_verify_max_len(
228 hf_config=self.hf_text_config,
229 max_model_len=max_model_len,

~/.local/lib/python3.10/site-packages/vllm/config.py in _get_and_verify_max_len(hf_config, max_model_len, disable_sliding_window, sliding_window_len, spec_target_max_model_len)
1745 scaling_factor = 1
1746 else:
-> 1747 assert "factor" in rope_scaling
1748 scaling_factor = rope_scaling["factor"]
1749 if rope_type == "yarn":

AssertionError:

@DarkLight1337
Copy link
Member

See my comment above: #7905 (comment)

@exceedzhang
Copy link

Hi, I am running this and got the same error as #8281. Could someone help me with this? Thank you!

Unrecognized keys in rope_scaling for 'rope_type'='default': {'mrope_section'}

AssertionError Traceback (most recent call last) /tmp/ipykernel_32014/2600037648.py in 9 del config.rope_scaling['mrope_section'] 10 ---> 11 llm = LLM( 12 model=MODEL_PATH, 13 limit_mm_per_prompt={'image': 10, 'video': 10},

~/.local/lib/python3.10/site-packages/vllm/entrypoints/llm.py in init(self, model, tokenizer, tokenizer_mode, skip_tokenizer_init, trust_remote_code, tensor_parallel_size, dtype, quantization, revision, tokenizer_revision, seed, gpu_memory_utilization, swap_space, cpu_offload_gb, enforce_eager, max_context_len_to_capture, max_seq_len_to_capture, disable_custom_all_reduce, disable_async_output_proc, **kwargs) 176 **kwargs, 177 ) --> 178 self.llm_engine = LLMEngine.from_engine_args( 179 engine_args, usage_context=UsageContext.LLM_CLASS) 180 self.request_counter = Counter()

~/.local/lib/python3.10/site-packages/vllm/engine/llm_engine.py in from_engine_args(cls, engine_args, usage_context, stat_loggers) 545 """Creates an LLM engine from the engine arguments.""" 546 # Create the engine configs. --> 547 engine_config = engine_args.create_engine_config() 548 executor_class = cls._get_executor_cls(engine_config) 549 # Create the LLM engine.

~/.local/lib/python3.10/site-packages/vllm/engine/arg_utils.py in create_engine_config(self) 842 843 device_config = DeviceConfig(device=self.device) --> 844 model_config = self.create_model_config() 845 846 if model_config.is_multimodal_model:

~/.local/lib/python3.10/site-packages/vllm/engine/arg_utils.py in create_model_config(self) 780 781 def create_model_config(self) -> ModelConfig: --> 782 return ModelConfig( 783 model=self.model, 784 tokenizer=self.tokenizer,

~/.local/lib/python3.10/site-packages/vllm/config.py in init(self, model, tokenizer, tokenizer_mode, trust_remote_code, dtype, seed, revision, code_revision, rope_scaling, rope_theta, tokenizer_revision, max_model_len, spec_target_max_model_len, quantization, quantization_param_path, enforce_eager, max_context_len_to_capture, max_seq_len_to_capture, max_logprobs, disable_sliding_window, skip_tokenizer_init, served_model_name, limit_mm_per_prompt, use_async_output_proc, override_neuron_config, config_format) 225 self.disable_sliding_window = True 226 --> 227 self.max_model_len = _get_and_verify_max_len( 228 hf_config=self.hf_text_config, 229 max_model_len=max_model_len,

~/.local/lib/python3.10/site-packages/vllm/config.py in _get_and_verify_max_len(hf_config, max_model_len, disable_sliding_window, sliding_window_len, spec_target_max_model_len) 1745 scaling_factor = 1 1746 else: -> 1747 assert "factor" in rope_scaling 1748 scaling_factor = rope_scaling["factor"] 1749 if rope_type == "yarn":

AssertionError:

huggingface/transformers#33401

@YuanLiuuuuuu
Copy link

YuanLiuuuuuu commented Sep 26, 2024

Unrecognized keys in rope_scaling for 'rope_type'='default': {'mrope_section'} Traceback (most recent call last): File "/workspace/lite/test1.py", line 10, in llm = LLM( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 178, in init self.llm_engine = LLMEngine.from_engine_args( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 547, in from_engine_args engine_config = engine_args.create_engine_config() File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 844, in create_engine_config model_config = self.create_model_config() File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 782, in create_model_config return ModelConfig( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/config.py", line 227, in init self.max_model_len = _get_and_verify_max_len( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/config.py", line 1747, in _get_and_verify_max_len assert "factor" in rope_scaling AssertionError

@AlexanderChen1989 please make sure you have installed this particular version of transformers pip install git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830

This version of transformer will raise the following error:

ModuleNotFoundError: No module named 'transformers.models.mllama'

@DarkLight1337
Copy link
Member

DarkLight1337 commented Sep 26, 2024

T

Unrecognized keys in rope_scaling for 'rope_type'='default': {'mrope_section'} Traceback (most recent call last): File "/workspace/lite/test1.py", line 10, in llm = LLM( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 178, in init self.llm_engine = LLMEngine.from_engine_args( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 547, in from_engine_args engine_config = engine_args.create_engine_config() File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 844, in create_engine_config model_config = self.create_model_config() File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 782, in create_model_config return ModelConfig( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/config.py", line 227, in init self.max_model_len = _get_and_verify_max_len( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/config.py", line 1747, in _get_and_verify_max_len assert "factor" in rope_scaling AssertionError

@AlexanderChen1989 please make sure you have installed this particular version of transformers pip install git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830

This version of transformer will raise the following error:

ModuleNotFoundError: No module named 'transformers.models.mllama'

The current version of vLLM requires transformers>=4.45. Qwen2-VL has only just been made compatible with transformers>=4.45 in vLLM, so you'll have to install vLLM from source.

@chenzhengda
Copy link

@fyabc Hi, I've noticed that in the Qwen2 VL chat template, there is no '\n' after <|vision_end|>, but there is one when launched through the vllm API server. This seems to be a bug.

@SepehrV
Copy link

SepehrV commented Oct 3, 2024

Unrecognized keys in rope_scaling for 'rope_type'='default': {'mrope_section'} Traceback (most recent call last): File "/workspace/lite/test1.py", line 10, in llm = LLM( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/entrypoints/llm.py", line 178, in init self.llm_engine = LLMEngine.from_engine_args( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/llm_engine.py", line 547, in from_engine_args engine_config = engine_args.create_engine_config() File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 844, in create_engine_config model_config = self.create_model_config() File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/engine/arg_utils.py", line 782, in create_model_config return ModelConfig( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/config.py", line 227, in init self.max_model_len = _get_and_verify_max_len( File "/workspace/lite/venv/lib/python3.10/site-packages/vllm/config.py", line 1747, in _get_and_verify_max_len assert "factor" in rope_scaling AssertionError

@AlexanderChen1989 please make sure you have installed this particular version of transformers pip install git+https://github.com/huggingface/transformers@21fac7abba2a37fae86106f87fcf9974fd1e3830

this transformers version is not compatible with the latest VLLM anymore. (mllama missing).

I tried this using transformers after this fix huggingface/transformers#33753 but vllm is still throwing assert "factor" in rope_scaling

@DarkLight1337
Copy link
Member

Yeah, you need to install vLLM from source to fix the problem now. Please refer to the top post in this thread.

@DarkLight1337 DarkLight1337 mentioned this pull request Oct 5, 2024
1 task
@fyabc
Copy link
Contributor Author

fyabc commented Oct 8, 2024

@fyabc Hi, I've noticed that in the Qwen2 VL chat template, there is no '\n' after <|vision_end|>, but there is one when launched through the vllm API server. This seems to be a bug.

@chenzhengda Hi, by default all mm placeholders are joined with "\n" separator (see vllm.entrypoints.chat_utils._parse_chat_message_content_parts for detailed implementation). It seems that we need to refactor chat_utils.py to fix this bug.
@ywang96 @DarkLight1337 Please also take a look at this problem and check my comments.

@DarkLight1337
Copy link
Member

DarkLight1337 commented Oct 8, 2024

@fyabc Hi, I've noticed that in the Qwen2 VL chat template, there is no '\n' after <|vision_end|>, but there is one when launched through the vllm API server. This seems to be a bug.

@chenzhengda Hi, by default all mm placeholders are joined with "\n" separator (see vllm.entrypoints.chat_utils._parse_chat_message_content_parts for detailed implementation). It seems that we need to refactor chat_utils.py to fix this bug.
@ywang96 @DarkLight1337 Please also take a look at this problem and check my comments.

Thanks for pointing that out. We have this on our multimodality plan but haven't gotten around to implementing it yet. Since many HF chat templates do not specify how to combine placeholder multimodal tokens (like <image>) together, we hardcode this to inserting newlines for now. The semantics of HF chat template and preprocessing differs between models so we need to have more thoughts on this. An RFC to discuss this in detail would be nice, WDYT @ywang96 ?

@seanll-ke
Copy link

Does qwen2vl deployed using vllm support function call?

@baisong666
Copy link

Does qwen2vl deployed using vllm support function call?

+1

@whyiug
Copy link
Contributor

whyiug commented Oct 16, 2024

@fyabc Hi, I've noticed that in the Qwen2 VL chat template, there is no '\n' after <|vision_end|>, but there is one when launched through the vllm API server. This seems to be a bug.

@chenzhengda Hi, by default all mm placeholders are joined with "\n" separator (see vllm.entrypoints.chat_utils._parse_chat_message_content_parts for detailed implementation). It seems that we need to refactor chat_utils.py to fix this bug.
@ywang96 @DarkLight1337 Please also take a look at this problem and check my comments.

Thanks for pointing that out. We have this on our multimodality plan but haven't gotten around to implementing it yet. Since many HF chat templates do not specify how to combine placeholder multimodal tokens (like <image>) together, we hardcode this to inserting newlines for now. The semantics of HF chat template and preprocessing differs between models so we need to have more thoughts on this. An RFC to discuss this in detail would be nice, WDYT @ywang96 ?

For those of who want a temporary fix for this, here's how I do it, then reinstall vllm. Also expect an official fix soon.
whyiug@3495e80#diff-31e6bd0df09a47b5587701203d558701ac46e4f85bf7db83632da9990eaef198R382

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Alvant <alvasian@yandex.ru>
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Amit Garg <mitgarg17495@gmail.com>
@mgoin
Copy link
Member

mgoin commented Oct 29, 2024

EDIT: Nevermind, I just had a silly issue where weight_scale was being read as the weight parameter, PR here #9817

Hey @fyabc I am working on expanding quantization for multimodal models and currently this special case in the qwen2vl weight loading is causing issues

if "visual" in name and "qkv.weight" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size,
visual_embed_dim)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1, visual_embed_dim)
elif "visual" in name and "qkv.bias" in name:
visual_num_heads = self.config.vision_config.num_heads
visual_embed_dim = self.config.vision_config.embed_dim
head_size = visual_embed_dim // visual_num_heads
loaded_weight = loaded_weight.view(3, visual_num_heads,
head_size)
loaded_weight = loaded_weight.transpose(0, 1)
loaded_weight = loaded_weight.reshape(-1)

Could you offer insight into why this is required and if we could apply a transformation on the inputs rather than the weights?

LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Bug]: Qwen2-VL AssertionError: assert "factor" in rope_scaling. [New Model]: Qwen2-VL