Your current environment
The output of `python collect_env.py` is not available: the machine is network-isolated, so `collect_env.py` cannot download anything. Environment summary:
Python 3.8
8x A10 GPUs
Model: InternVL2-26B
vllm 0.5.5
vllm-flash-attn 2.6.1
torch 2.4.0
torchvision 0.19.0
🐛 Describe the bug
from dataclasses import dataclass
from typing import Literal

import torch
from PIL import Image

VLM_IMAGES_DIR = "vision_model_images"


@dataclass(frozen=True)
class ImageAsset:
    name: Literal["stop_sign", "cherry_blossom"]

    @property
    def pil_image(self) -> Image.Image:
        image_path = "image.jpg"
        return Image.open(image_path)


"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.
For most models, the prompt format should follow the corresponding examples
on the HuggingFace model repository.
"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams

# Input image and question
image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
question = "What is the content of this image?"


# InternVL
def run_internvl(question):
    model_name = "/home/tdj/model/InternVL2-26B"

    llm = LLM(
        model=model_name,
        trust_remote_code=True,
        gpu_memory_utilization=0.9,
        tensor_parallel_size=8,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    messages = [{"role": "user", "content": f"<image>\n{question}"}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Stop tokens for InternVL
    # model variants may have different stop tokens
    # please refer to the model card for the correct "stop words":
    # https://huggingface.co/OpenGVLab/InternVL2-2B#service
    stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"]
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
    return llm, prompt, stop_token_ids


model_example_map = {
    "internvl_chat": run_internvl,
}


def main():
    model = "internvl_chat"
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(
        temperature=0.2, max_tokens=64, stop_token_ids=stop_token_ids
    )

    # Single inference
    inputs = {
        "prompt": prompt,
        "multi_modal_data": {"image": image},
    }
    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    main()
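Before the failure, the rendered prompt itself looks fine. Here is a quick sanity check on top of the repro (a sketch; it assumes the same local model path as above and that the chat template passes the literal `<image>` tag through unchanged): the prompt should contain exactly one `<image>` tag, which vLLM's input processor later expands into placeholder tokens.

# Sanity-check sketch, not part of the official demo: confirm the chat
# template keeps exactly one literal "<image>" tag. Assumes the same local
# model path used in the repro above.
from transformers import AutoTokenizer

model_name = "/home/tdj/model/InternVL2-26B"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

messages = [{"role": "user",
             "content": "<image>\nWhat is the content of this image?"}]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# vLLM's InternVL input processor expands the single "<image>" tag into the
# run of placeholder tokens that the vision embeddings are merged into.
assert prompt.count("<image>") == 1, prompt
print(prompt)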
Here is my error stack trace (from worker pid=85386; every worker process raises the same ValueError):
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/worker/worker_base.py", line 69, in start_worker_execution_loop
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     output = self.execute_model(execute_model_req=None)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/worker/worker_base.py", line 322, in execute_model
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/worker/model_runner.py", line 1415, in execute_model
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     hidden_or_intermediate_states = model_executable(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/model_executor/models/internvl.py", line 459, in forward
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     inputs_embeds = merge_multimodal_embeddings(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]   File "/root/anaconda3/envs/py3/lib/python3.8/site-packages/vllm/model_executor/models/utils.py", line 77, in merge_multimodal_embeddings
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226]     raise ValueError(
(VllmWorkerProcess pid=85386) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226] ValueError: Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders
(VllmWorkerProcess pid=85392) ERROR 08-29 21:27:26 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method start_worker_execution_loop: Attempted to assign 7 x 256 = 1792 multimodal tokens to 506 placeholders, Traceback (most recent call last):
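The numbers in the error give the shape of the mismatch: the vision tower produced 7 patches x 256 tokens per patch = 1792 image embeddings, but the tokenized prompt only reserved 506 placeholder positions for them. Below is a minimal sketch of the consistency check that raises here, paraphrased from the error message and the `merge_multimodal_embeddings` frame in the traceback; the tensor names and shapes are my assumptions, and the real function in `vllm/model_executor/models/utils.py` also scatters the embeddings into `inputs_embeds` when the counts match.

import torch

def check_placeholder_count(input_ids: torch.Tensor,
                            vision_embeddings: torch.Tensor,
                            image_token_id: int) -> None:
    # Placeholder positions the input processor reserved in the prompt.
    num_expected = int((input_ids == image_token_id).sum().item())  # 506 here
    # Embeddings the vision tower actually produced, assumed shaped as
    # (num_patches, tokens_per_patch, hidden_size).
    num_patches, tokens_per_patch, _ = vision_embeddings.shape
    num_actual = num_patches * tokens_per_patch  # 7 * 256 = 1792 here
    if num_actual != num_expected:
        raise ValueError(
            f"Attempted to assign {num_patches} x {tokens_per_patch} = "
            f"{num_actual} multimodal tokens to {num_expected} placeholders")

If this check fails, the prompt-side placeholder expansion (done by vLLM's input processor) and the image-side dynamic patching disagreed about how many patches the image was split into.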
If you have any questions, please feel free to contact me. I ran this exactly according to the official demo; the image comes from my local machine.
#6321
Is only the 2B model supported?