[Model] Add UltravoxModel and UltravoxConfig (vllm-project#7615)
Signed-off-by: Alvant <alvasian@yandex.ru>
petersalas authored and Alvant committed Oct 26, 2024
1 parent a3d7058 commit 1af67c6
Showing 33 changed files with 1,090 additions and 264 deletions.
7 changes: 6 additions & 1 deletion docs/source/models/supported_models.rst
@@ -186,7 +186,7 @@ Multimodal Language Models

* - Architecture
- Models
-  - Supported Modality(ies)
+  - Supported Modalities
- Example HuggingFace Models
- :ref:`LoRA <lora>`
* - :code:`Blip2ForConditionalGeneration`
@@ -234,6 +234,11 @@ Multimodal Language Models
- Image
- :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
-
* - :code:`UltravoxModel`
- Ultravox
- Audio
- :code:`fixie-ai/ultravox-v0_3`
-

.. note::
For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
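For quick reference, here is a condensed sketch of offline audio inference with the newly listed model, distilled from the examples/offline_inference_audio_language.py file added in this commit (the <|reserved_special_token_0|> placeholder marks where the audio clip is spliced into the prompt). Treat it as a sketch of the intended usage rather than canonical documentation:

from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

model_name = "fixie-ai/ultravox-v0_3"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build a chat prompt whose audio placeholder vLLM replaces with audio embeddings.
messages = [{
    'role': 'user',
    'content': "<|reserved_special_token_0|>\nWhat is recited in the audio?"
}]
prompt = tokenizer.apply_chat_template(messages,
                                       tokenize=False,
                                       add_generation_prompt=True)

llm = LLM(model=model_name)
outputs = llm.generate(
    {
        "prompt": prompt,
        # (audio waveform, sampling rate) tuple from the bundled test asset.
        "multi_modal_data": {
            "audio": AudioAsset("mary_had_lamb").audio_and_sample_rate
        },
    },
    sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)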
97 changes: 97 additions & 0 deletions examples/offline_inference_audio_language.py
@@ -0,0 +1,97 @@
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on vision language models.
For most models, the prompt format should follow corresponding examples
on HuggingFace model repository.
"""
from transformers import AutoTokenizer

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.utils import FlexibleArgumentParser

# Input audio and question
audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
question = "What is recited in the audio?"


# Ultravox 0.3
def run_ultravox(question):
model_name = "fixie-ai/ultravox-v0_3"

tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [{
'role': 'user',
'content': f"<|reserved_special_token_0|>\n{question}"
}]
prompt = tokenizer.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)

llm = LLM(model=model_name)
stop_token_ids = None
return llm, prompt, stop_token_ids


model_example_map = {
"ultravox": run_ultravox,
}


def main(args):
    model = args.model_type
    if model not in model_example_map:
        raise ValueError(f"Model type {model} is not supported.")

    llm, prompt, stop_token_ids = model_example_map[model](question)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    assert args.num_prompts > 0
    if args.num_prompts == 1:
        # Single inference
        inputs = {
            "prompt": prompt,
            "multi_modal_data": {
                "audio": audio_and_sample_rate
            },
        }
    else:
        # Batch inference
        inputs = [{
            "prompt": prompt,
            "multi_modal_data": {
                "audio": audio_and_sample_rate
            },
        } for _ in range(args.num_prompts)]

    outputs = llm.generate(inputs, sampling_params=sampling_params)

    for o in outputs:
        generated_text = o.outputs[0].text
        print(generated_text)


if __name__ == "__main__":
    parser = FlexibleArgumentParser(
        description='Demo on using vLLM for offline inference with '
        'audio language models')
    parser.add_argument('--model-type',
                        '-m',
                        type=str,
                        default="ultravox",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
    parser.add_argument('--num-prompts',
                        type=int,
                        default=1,
                        help='Number of prompts to run.')

    args = parser.parse_args()
    main(args)
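Usage note: with vLLM installed, the new example can be run directly; the flags map to the argparse definitions at the bottom of the file:

python examples/offline_inference_audio_language.py --model-type ultravox --num-prompts 4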
90 changes: 90 additions & 0 deletions examples/openai_audio_api_client.py
@@ -0,0 +1,90 @@
"""An example showing how to use vLLM to serve VLMs.
Launch the vLLM server with the following command:
vllm serve fixie-ai/ultravox-v0_3
"""
import base64

import requests
from openai import OpenAI

from vllm.assets.audio import AudioAsset

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

client = OpenAI(
    # defaults to os.environ.get("OPENAI_API_KEY")
    api_key=openai_api_key,
    base_url=openai_api_base,
)

models = client.models.list()
model = models.data[0].id

# Any format supported by librosa is supported
audio_url = AudioAsset("winning_call").url

# Use audio url in the payload
chat_completion_from_url = client.chat.completions.create(
    messages=[{
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What's in this audio?"
            },
            {
                "type": "audio_url",
                "audio_url": {
                    "url": audio_url
                },
            },
        ],
    }],
    model=model,
    max_tokens=64,
)

result = chat_completion_from_url.choices[0].message.content
print(f"Chat completion output:{result}")


# Use base64 encoded audio in the payload
def encode_audio_base64_from_url(audio_url: str) -> str:
    """Encode an audio retrieved from a remote url to base64 format."""

    with requests.get(audio_url) as response:
        response.raise_for_status()
        result = base64.b64encode(response.content).decode('utf-8')

    return result


audio_base64 = encode_audio_base64_from_url(audio_url=audio_url)
chat_completion_from_base64 = client.chat.completions.create(
    messages=[{
        "role":
        "user",
        "content": [
            {
                "type": "text",
                "text": "What's in this audio?"
            },
            {
                "type": "audio_url",
                "audio_url": {
                    # Any format supported by librosa is supported
                    "url": f"data:audio/ogg;base64,{audio_base64}"
                },
            },
        ],
    }],
    model=model,
    max_tokens=64,
)

result = chat_completion_from_base64.choices[0].message.content
print(f"Chat completion output:{result}")
31 changes: 19 additions & 12 deletions tests/conftest.py
@@ -9,14 +9,14 @@
from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict,
TypeVar, Union)

import numpy as np
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import snapshot_download
from PIL import Image
-from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM,
-                          AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
+from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                           BatchFeature)

from vllm import LLM, SamplingParams
@@ -216,8 +216,7 @@ def __init__(
*,
model_kwargs: Optional[Dict[str, Any]] = None,
is_embedding_model: bool = False,
-is_vision_model: bool = False,
-is_encoder_decoder_model: bool = False,
+auto_cls=AutoModelForCausalLM,
postprocess_inputs: Callable[[BatchEncoding],
BatchEncoding] = identity,
) -> None:
@@ -234,13 +233,6 @@ def __init__(
device="cpu",
).to(dtype=torch_dtype))
else:
-if is_vision_model:
-    auto_cls = AutoModelForVision2Seq
-elif is_encoder_decoder_model:
-    auto_cls = AutoModelForSeq2SeqLM
-else:
-    auto_cls = AutoModelForCausalLM
-
model_kwargs = model_kwargs if model_kwargs is not None else {}
self.model = self.wrap_device(
auto_cls.from_pretrained(
@@ -432,6 +424,7 @@ def generate_greedy_logprobs_limit(
max_tokens: int,
num_logprobs: int,
images: Optional[List[Image.Image]] = None,
audios: Optional[List[Tuple[np.ndarray, int]]] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
all_logprobs: List[List[Dict[int, float]]] = []
@@ -446,6 +439,11 @@ def generate_greedy_logprobs_limit(
if images is not None and images[i] is not None:
processor_kwargs["images"] = images[i]

if audios is not None:
audio, sr = audios[i]
processor_kwargs["audio"] = audio
processor_kwargs["sampling_rate"] = sr

inputs = self.processor(**processor_kwargs)
inputs = self.postprocess_inputs(inputs)

@@ -627,6 +625,8 @@ def generate_w_logprobs(
sampling_params: SamplingParams,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
assert sampling_params.logprobs is not None

@@ -638,6 +638,10 @@ def generate_w_logprobs(
for i, image in enumerate(images):
inputs[i]["multi_modal_data"] = {"image": image}

if audios is not None:
for i, audio in enumerate(audios):
inputs[i]["multi_modal_data"] = {"audio": audio}

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
return self._final_steps_generate_w_logprobs(req_outputs)
@@ -674,6 +678,8 @@ def generate_greedy_logprobs(
num_logprobs: int,
images: Optional[Union[List[Image.Image],
List[List[Image.Image]]]] = None,
audios: Optional[Union[List[Tuple[np.ndarray, int]],
List[List[Tuple[np.ndarray, int]]]]] = None,
stop_token_ids: Optional[List[int]] = None,
) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
greedy_logprobs_params = SamplingParams(temperature=0.0,
@@ -682,7 +688,8 @@
stop_token_ids=stop_token_ids)
outputs = self.generate_w_logprobs(prompts,
greedy_logprobs_params,
-images=images)
+images=images,
+audios=audios)

return [(output_ids, output_str, output_logprobs)
for output_ids, output_str, output_logprobs in outputs]
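For context, the new audios hooks let audio tests mirror the existing image tests. Below is a hypothetical sketch, not the test added by this commit: the model name, prompt string, and token/logprob budgets are illustrative assumptions, and the HF reference comparison via auto_cls is only indicated in a comment.

import pytest

from vllm.assets.audio import AudioAsset

MODEL_NAME = "fixie-ai/ultravox-v0_3"  # assumed model under test
AUDIO_PROMPT = "<|reserved_special_token_0|>\nWhat is recited in the audio?"  # assumed prompt


@pytest.mark.parametrize("dtype", ["half"])
def test_audio_generation(vllm_runner, dtype: str) -> None:
    # (waveform, sampling rate) tuples, matching the new audios= parameter.
    audios = [AudioAsset("mary_had_lamb").audio_and_sample_rate]
    # In practice the prompt would be wrapped with the model's chat template,
    # as in the offline example above.
    prompts = [AUDIO_PROMPT]

    with vllm_runner(MODEL_NAME, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(prompts,
                                                           max_tokens=64,
                                                           num_logprobs=5,
                                                           audios=audios)

    # A full test would also run the HF reference via hf_runner(...,
    # auto_cls=...) with the same audios and compare the logprobs.
    assert len(vllm_outputs) == len(prompts)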
@@ -10,6 +10,7 @@
"""

import pytest
from transformers import AutoModelForSeq2SeqLM

from vllm.utils import cuda_device_count_stateless

@@ -85,7 +86,7 @@ def test_models(
}

with hf_runner(model, dtype=dtype,
-is_encoder_decoder_model=True) as hf_model:
+auto_cls=AutoModelForSeq2SeqLM) as hf_model:
hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit(
test_prompts,
max_tokens,