Skip to content

Commit

Permalink
Pixtral (vllm-project#8377)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Wang <ywang@roblox.com>
  • Loading branch information
patrickvonplaten and ywang96 authored Sep 11, 2024
1 parent 775f00f commit d394787
Show file tree
Hide file tree
Showing 8 changed files with 807 additions and 9 deletions.
5 changes: 5 additions & 0 deletions docs/source/models/supported_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,11 @@ Multimodal Language Models
- Image\ :sup:`E+`
- :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc.
-
* - :code:`PixtralForConditionalGeneration`
- Pixtral
- Image\ :sup:`+`
- :code:`mistralai/Pixtral-12B-2409`
-
* - :code:`QWenLMHeadModel`
- Qwen-VL
- Image\ :sup:`E`
Expand Down
164 changes: 164 additions & 0 deletions examples/offline_inference_pixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# ruff: noqa
import argparse

from vllm import LLM
from vllm.sampling_params import SamplingParams

# This script is an offline demo for running Pixtral.
#
# If you want to run a server/client setup, please follow this code:
#
# - Server:
#
# ```bash
# vllm serve mistralai/Pixtral-12B-2409 --tokenizer_mode mistral --limit_mm_per_prompt 'image=4' --max_num_batched_tokens 16384
# ```
#
# - Client:
#
# ```bash
# curl --location 'http://<your-node-url>:8000/v1/chat/completions' \
# --header 'Content-Type: application/json' \
# --header 'Authorization: Bearer token' \
# --data '{
# "model": "mistralai/Pixtral-12B-2409",
# "messages": [
# {
# "role": "user",
# "content": [
# {"type" : "text", "text": "Describe this image in detail please."},
# {"type": "image_url", "image_url": {"url": "https://s3.amazonaws.com/cms.ipressroom.com/338/files/201808/5b894ee1a138352221103195_A680%7Ejogging-edit/A680%7Ejogging-edit_hero.jpg"}},
# {"type" : "text", "text": "and this one as well. Answer in French."},
# {"type": "image_url", "image_url": {"url": "https://www.wolframcloud.com/obj/resourcesystem/images/a0e/a0ee3983-46c6-4c92-b85d-059044639928/6af8cfb971db031b.png"}}
# ]
# }
# ]
# }'
# ```
#
# Usage:
# python demo.py simple
# python demo.py advanced


def run_simple_demo():
model_name = "mistralai/Pixtral-12B-2409"
sampling_params = SamplingParams(max_tokens=8192)

llm = LLM(model=model_name, tokenizer_mode="mistral")

prompt = "Describe this image in one sentence."
image_url = "https://picsum.photos/id/237/200/300"

messages = [
{
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": image_url
}
},
],
},
]
outputs = llm.chat(messages, sampling_params=sampling_params)

print(outputs[0].outputs[0].text)


def run_advanced_demo():
model_name = "mistralai/Pixtral-12B-2409"
max_img_per_msg = 5
max_tokens_per_img = 4096

sampling_params = SamplingParams(max_tokens=8192, temperature=0.7)
llm = LLM(
model=model_name,
tokenizer_mode="mistral",
limit_mm_per_prompt={"image": max_img_per_msg},
max_num_batched_tokens=max_img_per_msg * max_tokens_per_img,
)

prompt = "Describe the following image."

url_1 = "https://huggingface.co/datasets/patrickvonplaten/random_img/resolve/main/yosemite.png"
url_2 = "https://picsum.photos/seed/picsum/200/300"
url_3 = "https://picsum.photos/id/32/512/512"

messages = [
{
"role":
"user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": url_1
}
},
{
"type": "image_url",
"image_url": {
"url": url_2
}
},
],
},
{
"role": "assistant",
"content": "The images show nature.",
},
{
"role": "user",
"content": "More details please and answer only in French!.",
},
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": url_3
}
},
],
},
]

outputs = llm.chat(messages=messages, sampling_params=sampling_params)
print(outputs[0].outputs[0].text)


def main():
parser = argparse.ArgumentParser(
description="Run a demo in simple or advanced mode.")

parser.add_argument(
"mode",
choices=["simple", "advanced"],
help="Specify the demo mode: 'simple' or 'advanced'",
)

args = parser.parse_args()

if args.mode == "simple":
print("Running simple demo...")
run_simple_demo()
elif args.mode == "advanced":
print("Running advanced demo...")
run_advanced_demo()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pyzmq
msgspec
gguf == 0.9.1
importlib_metadata
mistral_common >= 1.3.4
mistral_common >= 1.4.0
pyyaml
six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12
einops # Required for Qwen2-VL.
64 changes: 64 additions & 0 deletions tests/models/test_pixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py`.
"""
import pytest

from vllm.sampling_params import SamplingParams

pytestmark = pytest.mark.vlm

MODELS = ["mistralai/Pixtral-12B-2409"]


@pytest.mark.skip(
reason=
"Model is too big, test passed on A100 locally but will OOM on CI machine."
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
image_urls = [
"https://picsum.photos/id/237/200/300",
"https://picsum.photos/seed/picsum/200/300"
]
expected = [
"The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.", # noqa
"The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset." # noqa
]
prompt = "Describe the image in one short sentence."

sampling_params = SamplingParams(max_tokens=512, temperature=0.0)

with vllm_runner(model, dtype=dtype,
tokenizer_mode="mistral") as vllm_model:

for i, image_url in enumerate(image_urls):
messages = [
{
"role":
"user",
"content": [{
"type": "text",
"text": prompt
}, {
"type": "image_url",
"image_url": {
"url": image_url
}
}]
},
]

outputs = vllm_model.model.chat(messages,
sampling_params=sampling_params)
assert outputs[0].outputs[0].text == expected[i]
3 changes: 2 additions & 1 deletion vllm/entrypoints/chat_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def _placeholder_str(self, modality: ModalityStr,
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma",
"pixtral"):
# These models do not use image tokens in the prompt
return None
if model_type == "qwen":
Expand Down
2 changes: 2 additions & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
}
Expand Down
Loading

0 comments on commit d394787

Please sign in to comment.