From 1af67c6157a5a0cd72421986a0a896b46ef69fc8 Mon Sep 17 00:00:00 2001
From: Peter Salas
Date: Wed, 21 Aug 2024 15:49:39 -0700
Subject: [PATCH] [Model] Add UltravoxModel and UltravoxConfig (#7615)

Signed-off-by: Alvant
---
 docs/source/models/supported_models.rst | 7 +-
 examples/offline_inference_audio_language.py | 97 ++++
 examples/openai_audio_api_client.py | 90 ++++
 tests/conftest.py | 31 +-
 ...t_basic_distributed_correctness_enc_dec.py | 3 +-
 tests/entrypoints/openai/test_audio.py | 148 +-----
 tests/models/test_bart.py | 3 +-
 tests/models/test_blip2.py | 5 +-
 tests/models/test_chameleon.py | 4 +-
 tests/models/test_llava.py | 5 +-
 tests/models/test_llava_image_embeds.py | 5 +-
 tests/models/test_llava_next.py | 5 +-
 tests/models/test_paligemma.py | 5 +-
 tests/models/test_qwen.py | 2 +-
 tests/models/test_ultravox.py | 151 ++++++
 vllm/assets/audio.py | 26 ++
 vllm/entrypoints/chat_utils.py | 6 +-
 vllm/model_executor/models/__init__.py | 3 +-
 vllm/model_executor/models/blip.py | 8 +-
 vllm/model_executor/models/chameleon.py | 8 +-
 vllm/model_executor/models/clip.py | 8 +-
 vllm/model_executor/models/fuyu.py | 4 +-
 vllm/model_executor/models/internvl.py | 2 +-
 vllm/model_executor/models/minicpmv.py | 4 +-
 vllm/model_executor/models/paligemma.py | 2 +-
 vllm/model_executor/models/phi3v.py | 2 +-
 vllm/model_executor/models/siglip.py | 8 +-
 vllm/model_executor/models/ultravox.py | 435 ++++++++++++++++++
 vllm/multimodal/image.py | 83 ----
 vllm/multimodal/utils.py | 90 +++-
 vllm/transformers_utils/config.py | 3 +-
 vllm/transformers_utils/configs/__init__.py | 2 +
 vllm/transformers_utils/configs/ultravox.py | 99 ++++
 33 files changed, 1090 insertions(+), 264 deletions(-)
 create mode 100644 examples/offline_inference_audio_language.py
 create mode 100644 examples/openai_audio_api_client.py
 create mode 100644 tests/models/test_ultravox.py
 create mode 100644 vllm/assets/audio.py
 create mode 100644 vllm/model_executor/models/ultravox.py
 create mode 100644 vllm/transformers_utils/configs/ultravox.py

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
index c761d1b32cd91..1692e13c4ec06 100644
--- a/docs/source/models/supported_models.rst
+++ b/docs/source/models/supported_models.rst
@@ -186,7 +186,7 @@ Multimodal Language Models
   * - Architecture
     - Models
-    - Supported Modality(ies)
+    - Supported Modalities
     - Example HuggingFace Models
     - :ref:`LoRA `
   * - :code:`Blip2ForConditionalGeneration`
@@ -234,6 +234,11 @@ Multimodal Language Models
     - Image
     - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc.
     -
+  * - :code:`UltravoxModel`
+    - Ultravox
+    - Audio
+    - :code:`fixie-ai/ultravox-v0_3`
+    -
 .. note::
   For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now.
diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py
new file mode 100644
index 0000000000000..7b886f8e2001a
--- /dev/null
+++ b/examples/offline_inference_audio_language.py
@@ -0,0 +1,97 @@
+"""
+This example shows how to use vLLM for running offline inference
+with the correct prompt format on audio language models.
+
+For most models, the prompt format should follow corresponding examples
+on the HuggingFace model repository.
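+
+Example invocation (both flags are optional and shown with their defaults):
+    python examples/offline_inference_audio_language.py --model-type ultravox --num-prompts 1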
+"""
+from transformers import AutoTokenizer
+
+from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
+from vllm.utils import FlexibleArgumentParser
+
+# Input audio and question
+audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate
+question = "What is recited in the audio?"
+
+
+# Ultravox 0.3
+def run_ultravox(question):
+    model_name = "fixie-ai/ultravox-v0_3"
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    messages = [{
+        'role': 'user',
+        'content': f"<|reserved_special_token_0|>\n{question}"
+    }]
+    prompt = tokenizer.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    llm = LLM(model=model_name)
+    stop_token_ids = None
+    return llm, prompt, stop_token_ids
+
+
+model_example_map = {
+    "ultravox": run_ultravox,
+}
+
+
+def main(args):
+    model = args.model_type
+    if model not in model_example_map:
+        raise ValueError(f"Model type {model} is not supported.")
+
+    llm, prompt, stop_token_ids = model_example_map[model](question)
+
+    # We set temperature to 0.2 so that outputs can differ even when all
+    # prompts are identical during batch inference.
+    sampling_params = SamplingParams(temperature=0.2,
+                                     max_tokens=64,
+                                     stop_token_ids=stop_token_ids)
+
+    assert args.num_prompts > 0
+    if args.num_prompts == 1:
+        # Single inference
+        inputs = {
+            "prompt": prompt,
+            "multi_modal_data": {
+                "audio": audio_and_sample_rate
+            },
+        }
+
+    else:
+        # Batch inference
+        inputs = [{
+            "prompt": prompt,
+            "multi_modal_data": {
+                "audio": audio_and_sample_rate
+            },
+        } for _ in range(args.num_prompts)]
+
+    outputs = llm.generate(inputs, sampling_params=sampling_params)
+
+    for o in outputs:
+        generated_text = o.outputs[0].text
+        print(generated_text)
+
+
+if __name__ == "__main__":
+    parser = FlexibleArgumentParser(
+        description='Demo on using vLLM for offline inference with '
+        'audio language models')
+    parser.add_argument('--model-type',
+                        '-m',
+                        type=str,
+                        default="ultravox",
+                        choices=model_example_map.keys(),
+                        help='Huggingface "model_type".')
+    parser.add_argument('--num-prompts',
+                        type=int,
+                        default=1,
+                        help='Number of prompts to run.')
+
+    args = parser.parse_args()
+    main(args)
diff --git a/examples/openai_audio_api_client.py b/examples/openai_audio_api_client.py
new file mode 100644
index 0000000000000..80a972683871f
--- /dev/null
+++ b/examples/openai_audio_api_client.py
@@ -0,0 +1,90 @@
+"""An example showing how to use vLLM to serve an audio language model.
+
+Launch the vLLM server with the following command:
+vllm serve fixie-ai/ultravox-v0_3
+"""
+import base64
+
+import requests
+from openai import OpenAI
+
+from vllm.assets.audio import AudioAsset
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+
+client = OpenAI(
+    # defaults to os.environ.get("OPENAI_API_KEY")
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+models = client.models.list()
+model = models.data[0].id
+
+# Any format supported by librosa is supported
+audio_url = AudioAsset("winning_call").url
+
+# Use audio url in the payload
+chat_completion_from_url = client.chat.completions.create(
+    messages=[{
+        "role":
+        "user",
+        "content": [
+            {
+                "type": "text",
+                "text": "What's in this audio?"
+ }, + { + "type": "audio_url", + "audio_url": { + "url": audio_url + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_url.choices[0].message.content +print(f"Chat completion output:{result}") + + +# Use base64 encoded audio in the payload +def encode_audio_base64_from_url(audio_url: str) -> str: + """Encode an audio retrieved from a remote url to base64 format.""" + + with requests.get(audio_url) as response: + response.raise_for_status() + result = base64.b64encode(response.content).decode('utf-8') + + return result + + +audio_base64 = encode_audio_base64_from_url(audio_url=audio_url) +chat_completion_from_base64 = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What's in this audio?" + }, + { + "type": "audio_url", + "audio_url": { + # Any format supported by librosa is supported + "url": f"data:audio/ogg;base64,{audio_base64}" + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_base64.choices[0].message.content +print(f"Chat completion output:{result}") diff --git a/tests/conftest.py b/tests/conftest.py index 08a2c8fcda021..ae362b228d9d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,14 +9,14 @@ from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union) +import numpy as np import pytest import torch import torch.nn as nn import torch.nn.functional as F from huggingface_hub import snapshot_download from PIL import Image -from transformers import (AutoModelForCausalLM, AutoModelForSeq2SeqLM, - AutoModelForVision2Seq, AutoTokenizer, BatchEncoding, +from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature) from vllm import LLM, SamplingParams @@ -216,8 +216,7 @@ def __init__( *, model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, - is_vision_model: bool = False, - is_encoder_decoder_model: bool = False, + auto_cls=AutoModelForCausalLM, postprocess_inputs: Callable[[BatchEncoding], BatchEncoding] = identity, ) -> None: @@ -234,13 +233,6 @@ def __init__( device="cpu", ).to(dtype=torch_dtype)) else: - if is_vision_model: - auto_cls = AutoModelForVision2Seq - elif is_encoder_decoder_model: - auto_cls = AutoModelForSeq2SeqLM - else: - auto_cls = AutoModelForCausalLM - model_kwargs = model_kwargs if model_kwargs is not None else {} self.model = self.wrap_device( auto_cls.from_pretrained( @@ -432,6 +424,7 @@ def generate_greedy_logprobs_limit( max_tokens: int, num_logprobs: int, images: Optional[List[Image.Image]] = None, + audios: Optional[List[Tuple[np.ndarray, int]]] = None, **kwargs: Any, ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: all_logprobs: List[List[Dict[int, float]]] = [] @@ -446,6 +439,11 @@ def generate_greedy_logprobs_limit( if images is not None and images[i] is not None: processor_kwargs["images"] = images[i] + if audios is not None: + audio, sr = audios[i] + processor_kwargs["audio"] = audio + processor_kwargs["sampling_rate"] = sr + inputs = self.processor(**processor_kwargs) inputs = self.postprocess_inputs(inputs) @@ -627,6 +625,8 @@ def generate_w_logprobs( sampling_params: SamplingParams, images: Optional[Union[List[Image.Image], List[List[Image.Image]]]] = None, + audios: Optional[Union[List[Tuple[np.ndarray, int]], + List[List[Tuple[np.ndarray, int]]]]] = None ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None @@ -638,6 +638,10 @@ def generate_w_logprobs( for i, 
image in enumerate(images): inputs[i]["multi_modal_data"] = {"image": image} + if audios is not None: + for i, audio in enumerate(audios): + inputs[i]["multi_modal_data"] = {"audio": audio} + req_outputs = self.model.generate(inputs, sampling_params=sampling_params) return self._final_steps_generate_w_logprobs(req_outputs) @@ -674,6 +678,8 @@ def generate_greedy_logprobs( num_logprobs: int, images: Optional[Union[List[Image.Image], List[List[Image.Image]]]] = None, + audios: Optional[Union[List[Tuple[np.ndarray, int]], + List[List[Tuple[np.ndarray, int]]]]] = None, stop_token_ids: Optional[List[int]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, @@ -682,7 +688,8 @@ def generate_greedy_logprobs( stop_token_ids=stop_token_ids) outputs = self.generate_w_logprobs(prompts, greedy_logprobs_params, - images=images) + images=images, + audios=audios) return [(output_ids, output_str, output_logprobs) for output_ids, output_str, output_logprobs in outputs] diff --git a/tests/distributed/test_basic_distributed_correctness_enc_dec.py b/tests/distributed/test_basic_distributed_correctness_enc_dec.py index 9850c823ff5da..f00d5ef584a2a 100644 --- a/tests/distributed/test_basic_distributed_correctness_enc_dec.py +++ b/tests/distributed/test_basic_distributed_correctness_enc_dec.py @@ -10,6 +10,7 @@ """ import pytest +from transformers import AutoModelForSeq2SeqLM from vllm.utils import cuda_device_count_stateless @@ -85,7 +86,7 @@ def test_models( } with hf_runner(model, dtype=dtype, - is_encoder_decoder_model=True) as hf_model: + auto_cls=AutoModelForSeq2SeqLM) as hf_model: hf_outputs = (hf_model.generate_encoder_decoder_greedy_logprobs_limit( test_prompts, max_tokens, diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 39b47f3033715..6dc8dde667389 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -1,138 +1,36 @@ -import math -import sys -import time -from typing import Dict, List, Optional, Tuple, Union, cast -from unittest.mock import patch - -import librosa -import numpy as np +from typing import Dict, List + import openai import pytest -import requests -import torch - -from vllm import ModelRegistry -from vllm.config import MultiModalConfig -from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import LLMInputs -from vllm.inputs.registry import InputContext -from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) -from vllm.multimodal.utils import encode_audio_base64, fetch_audio -from vllm.utils import get_open_port -from ...utils import VLLM_PATH +from vllm.assets.audio import AudioAsset +from vllm.multimodal.utils import encode_audio_base64, fetch_audio -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() +from ...utils import RemoteOpenAIServer -MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "fixie-ai/ultravox-v0_3" TEST_AUDIO_URLS = [ - "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg", + AudioAsset("winning_call").url, ] -def server_function(port): - - def fake_input_mapper(ctx: InputContext, data: object): - assert isinstance(data, tuple) - (audio, sr) = 
cast(Tuple[np.ndarray, Union[float, int]], data) - - # Resample it to 1 sample per second - audio = librosa.resample(audio, orig_sr=sr, target_sr=1) - return MultiModalInputs({"processed_audio": torch.from_numpy(audio)}) - - def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") - if multi_modal_data is None or "audio" not in multi_modal_data: - return llm_inputs - - audio, sr = multi_modal_data.get("audio") - audio_duration = math.ceil(len(audio) / sr) - - new_prompt, new_token_ids = repeat_and_pad_image_tokens( - cached_get_tokenizer(ctx.model_config.tokenizer), - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], - image_token_id=62, # "_" - repeat_count=audio_duration) - - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - - @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper) - @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", lambda *_, **__: 100) - @INPUT_REGISTRY.register_input_processor(fake_input_processor) - class FakeAudioModel(OPTForCausalLM, SupportsMultiModal): - - def __init__(self, *args, multimodal_config: MultiModalConfig, - **kwargs): - assert multimodal_config is not None - super().__init__(*args, **kwargs) - - def forward( - self, - *args, - processed_audio: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - return super().forward(*args, **kwargs) - - ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel) - - with patch( - "vllm.entrypoints.chat_utils._mm_token_str", - lambda *_, **__: "_"), patch( - "vllm.model_executor.models.ModelRegistry.is_multimodal_model" - ) as mock: - mock.return_value = True - sys.argv = ["placeholder.py"] + \ - (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 " - "--dtype bfloat16 --enforce-eager --api-key token-abc123 " - f"--port {port} --chat-template {chatml_jinja_path} " - "--disable-frontend-multiprocessing").split() - import runpy - runpy.run_module('vllm.entrypoints.openai.api_server', - run_name='__main__') +@pytest.fixture(scope="module") +def server(): + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") -def client(): - port = get_open_port() - ctx = torch.multiprocessing.get_context("spawn") - server = ctx.Process(target=server_function, args=(port, )) - server.start() - MAX_SERVER_START_WAIT_S = 60 - client = openai.AsyncOpenAI( - base_url=f"http://localhost:{port}/v1", - api_key="token-abc123", - ) - # run health check - health_url = f"http://localhost:{port}/health" - start = time.time() - while True: - try: - if requests.get(health_url).status_code == 200: - break - except Exception as err: - result = server.exitcode - if result is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError("Server failed to start in time.") from err - - try: - yield client - finally: - server.kill() +def client(server): + return server.get_async_client() @pytest.fixture(scope="session") @@ -176,7 +74,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=36, total_tokens=46) + completion_tokens=10, 
prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message @@ -231,7 +129,7 @@ async def test_single_chat_session_audio_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=36, total_tokens=46) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message diff --git a/tests/models/test_bart.py b/tests/models/test_bart.py index 9bca5a86f1241..660b61d1a7ade 100644 --- a/tests/models/test_bart.py +++ b/tests/models/test_bart.py @@ -12,6 +12,7 @@ # (xFormers, etc.) import pytest + from transformers import AutoModelForSeq2SeqLM from vllm.sequence import SampleLogprobs @@ -131,7 +132,7 @@ def test_models( } with hf_runner(model, dtype=dtype, - is_encoder_decoder_model=True) as hf_model: + auto_cls=AutoModelForSeq2SeqLM) as hf_model: hf_outputs = ( hf_model.generate_encoder_decoder_greedy_logprobs_limit( test_case_prompts, diff --git a/tests/models/test_blip2.py b/tests/models/test_blip2.py index 64b7a77404b98..5d48bad0d7b35 100644 --- a/tests/models/test_blip2.py +++ b/tests/models/test_blip2.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple import pytest -from transformers import AutoTokenizer +from transformers import AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -80,7 +80,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_chameleon.py b/tests/models/test_chameleon.py index 5e7e0e6258f8a..e02b4b1ed72bd 100644 --- a/tests/models/test_chameleon.py +++ b/tests/models/test_chameleon.py @@ -1,7 +1,7 @@ from typing import List, Optional, Type import pytest -from transformers import BatchEncoding +from transformers import AutoModelForVision2Seq, BatchEncoding from vllm.multimodal.utils import rescale_image_size from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -74,7 +74,7 @@ def process(hf_inputs: BatchEncoding): with hf_runner(model, dtype=dtype, postprocess_inputs=process, - is_vision_model=True) as hf_model: + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index edaf7d400eb53..93634f245cee7 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,7 +1,8 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer, BatchEncoding +from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, + BatchEncoding) from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -124,7 +125,7 @@ def process(hf_inputs: BatchEncoding): with hf_runner(model, dtype=dtype, postprocess_inputs=process, - is_vision_model=True) as hf_model: + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava_image_embeds.py 
b/tests/models/test_llava_image_embeds.py index 63ccd1f6625c8..cc444fe32e79b 100644 --- a/tests/models/test_llava_image_embeds.py +++ b/tests/models/test_llava_image_embeds.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.sequence import SampleLogprobs @@ -105,7 +105,8 @@ def run_test( for prompts, images in vllm_inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 2bd27f888680d..9cf55c0858df0 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type, overload import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -129,7 +129,8 @@ def run_test( for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_paligemma.py b/tests/models/test_paligemma.py index 038a22f71acad..beddaaf608a18 100644 --- a/tests/models/test_paligemma.py +++ b/tests/models/test_paligemma.py @@ -2,7 +2,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -102,7 +102,8 @@ def run_test( for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForVision2Seq) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 03605e3b34810..0f974fcc1885c 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -26,7 +26,7 @@ def test_text_only_qwen_model( # for qwen-vl is still unsupported in VLLM. In the near-future, the # implementation and this test will be extended to consider # visual inputs as well. 
- with hf_runner(model, dtype=dtype, is_vision_model=False) as hf_model: + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py new file mode 100644 index 0000000000000..98de10aa08408 --- /dev/null +++ b/tests/models/test_ultravox.py @@ -0,0 +1,151 @@ +from typing import List, Optional, Tuple, Type + +import librosa +import numpy as np +import pytest +from transformers import AutoModel, AutoTokenizer, BatchEncoding + +from vllm.assets.audio import AudioAsset +from vllm.sequence import SampleLogprobs +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + +from ..conftest import HfRunner, VllmRunner +from .utils import check_logprobs_close + +pytestmark = pytest.mark.vlm + +MODEL_NAME = "fixie-ai/ultravox-v0_3" + +AudioTuple = Tuple[np.ndarray, int] + + +@pytest.fixture(scope="session") +def audio_and_sample_rate(): + return AudioAsset("mary_had_lamb").audio_and_sample_rate + + +@pytest.fixture +def prompts_and_audios(audio_and_sample_rate): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + + vllm_placeholder = "<|reserved_special_token_0|>" + hf_placeholder = "<|audio|>" + + question = "What's in the audio?" + vllm_prompt = tokenizer.apply_chat_template( + [{ + 'role': 'user', + 'content': f"{vllm_placeholder}\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + hf_prompt = tokenizer.apply_chat_template( + [{ + 'role': 'user', + 'content': f"{hf_placeholder}\n{question}" + }], + tokenize=False, + add_generation_prompt=True) + + return [(vllm_prompt, hf_prompt, audio_and_sample_rate)] + + +def vllm_to_hf_output(vllm_output: Tuple[List[int], str, + Optional[SampleLogprobs]], + model: str): + """Sanitize vllm output to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + tokenizer = AutoTokenizer.from_pretrained(model) + eos_token_id = tokenizer.eos_token_id + + hf_output_ids = output_ids[:] + hf_output_str = output_str + if hf_output_ids[-1] == eos_token_id: + hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) + + return hf_output_ids, hf_output_str, out_logprobs + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + prompts_and_audios: List[Tuple[str, str, AudioTuple]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + """Inference result should be the same between hf and vllm.""" + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). 
+ + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_audio = [ + vllm_model.generate_greedy_logprobs([vllm_prompt], + max_tokens, + num_logprobs=num_logprobs, + audios=[audio]) + for vllm_prompt, _, audio in prompts_and_audios + ] + + def process(hf_inputs: BatchEncoding): + hf_inputs["audio_values"] = hf_inputs["audio_values"] \ + .to(torch_dtype) # type: ignore + return hf_inputs + + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + auto_cls=AutoModel) as hf_model: + + hf_outputs_per_audio = [ + hf_model.generate_greedy_logprobs_limit( + [hf_prompt], + max_tokens, + num_logprobs=num_logprobs, + audios=[(librosa.resample(audio[0], + orig_sr=audio[1], + target_sr=16000), 16000)]) + for _, hf_prompt, audio in prompts_and_audios + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio, + vllm_outputs_per_audio): + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=[ + vllm_to_hf_output(vllm_output, model) + for vllm_output in vllm_outputs + ], + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str, + max_tokens: int, num_logprobs: int) -> None: + run_test( + hf_runner, + vllm_runner, + prompts_and_audios, + MODEL_NAME, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/vllm/assets/audio.py b/vllm/assets/audio.py new file mode 100644 index 0000000000000..b00a61ebfec65 --- /dev/null +++ b/vllm/assets/audio.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass +from typing import Literal, Tuple +from urllib.parse import urljoin + +import librosa +import numpy as np + +from vllm.assets.base import get_vllm_public_assets, vLLM_S3_BUCKET_URL + +ASSET_DIR = "multimodal_asset" + + +@dataclass(frozen=True) +class AudioAsset: + name: Literal["winning_call", "mary_had_lamb"] + + @property + def audio_and_sample_rate(self) -> Tuple[np.ndarray, int]: + + audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg", + s3_prefix=ASSET_DIR) + return librosa.load(audio_path, sr=None) + + @property + def url(self) -> str: + return urljoin(vLLM_S3_BUCKET_URL, f"{ASSET_DIR}/{self.name}.ogg") diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 48fd1333d8f40..19d1095084293 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -117,8 +117,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, modality: Literal["image", "audio"]) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) + model_type = model_config.hf_config.model_type if modality == "image": - model_type = model_config.hf_config.model_type if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return "<|image_1|>" @@ -134,7 +134,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": - raise TypeError("No audio models are supported yet.") + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") else: raise TypeError(f"Unknown modality: {modality}") diff --git 
a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 32cafa845a6e3..bdf6e502ea112 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -61,7 +61,7 @@ "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), } _EMBEDDING_MODELS = { @@ -83,6 +83,7 @@ "PaliGemmaForConditionalGeneration": ("paligemma", "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "UltravoxModel": ("ultravox", "UltravoxModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 69e777152e3d4..830680fd990bf 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -15,8 +15,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -97,11 +97,11 @@ def input_processor_for_blip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 788d22db9d5a8..a335e1766b2a9 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -30,8 +30,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SamplerOutput, SequenceData) from vllm.utils import print_warning_once @@ -124,11 +124,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=CHAMELEON_IMAGE_TOKEN_ID, + placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 24eeefdfccf00..0933966055330 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -16,8 +16,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -103,11 +103,11 @@ def input_processor_for_clip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 2ef23819b69a2..cfc2a5288a37b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -36,8 +36,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import (cached_get_image_processor, - cached_get_tokenizer) +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SamplerOutput, SequenceData) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b379c86c1912b..c996f0b73f293 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -23,7 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 99a3c5dab39e4..29f3640e2458b 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -54,8 +54,8 @@ from vllm.model_executor.models.qwen2 import Qwen2Model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import (cached_get_image_processor, - cached_get_tokenizer) +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SamplerOutput, SequenceData) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8beb2778fe37a..8cb5065ed79ec 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -16,7 +16,7 @@ from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsMultiModal diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 
328f4e6fa827c..9ccd6ef6d9ace 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -37,7 +37,7 @@ from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 426af7fee9544..7f6186fa010a4 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -24,8 +24,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData @@ -112,11 +112,11 @@ def input_processor_for_siglip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py new file mode 100644 index 0000000000000..842264f765866 --- /dev/null +++ b/vllm/model_executor/models/ultravox.py @@ -0,0 +1,435 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py +"""PyTorch Ultravox model.""" + +import itertools +import math +from array import array +from functools import lru_cache +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union, cast) + +import librosa +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from transformers.models.whisper import WhisperFeatureExtractor +from transformers.models.whisper.modeling_whisper import WhisperEncoder + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.inputs.data import LLMInputs +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import (filter_weights, + init_vllm_registered_model, + merge_multimodal_embeddings) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, 
SamplerOutput, SequenceData +from vllm.transformers_utils.configs.ultravox import UltravoxConfig + +_AUDIO_PLACEHOLDER_TOKEN = 128002 +_AUDIO_TOKENS_PER_SECOND = 6.25 + +logger = init_logger(__name__) + + +class UltravoxAudioFeatureInputs(TypedDict): + type: Literal["audio_features"] + data: Union[torch.Tensor, List[torch.Tensor]] + """Shape: `(batch_size, 80, M)""" + + +class UltravoxAudioEmbeddingInputs(TypedDict): + type: Literal["audio_embeds"] + data: torch.Tensor + + +UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, + UltravoxAudioEmbeddingInputs] + + +@lru_cache +def cached_feature_extractor(model_id: str) -> WhisperFeatureExtractor: + return WhisperFeatureExtractor.from_pretrained(model_id) + + +def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor: + return cached_feature_extractor( + ctx.get_hf_config(UltravoxConfig).audio_model_id) + + +def get_ultravox_max_audio_tokens(ctx: InputContext): + feature_extractor = whisper_feature_extractor(ctx) + return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) + + +def dummy_data_for_ultravox( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +): + feature_extractor = whisper_feature_extractor(ctx) + + audio_count = mm_counts["audio"] + + audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [ + _AUDIO_PLACEHOLDER_TOKEN + ]) * get_ultravox_max_audio_tokens(ctx) * audio_count + other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, + [0]) * (seq_len - len(audio_token_ids)) + + audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) + mm_dict = { + "audio": + audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count + } + + return (SequenceData(audio_token_ids + other_token_ids), mm_dict) + + +def input_mapper_for_ultravox(ctx: InputContext, data: object): + if isinstance(data, tuple): + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + feature_extractor = whisper_feature_extractor(ctx) + + if sr != feature_extractor.sampling_rate: + audio = librosa.resample(audio, + orig_sr=sr, + target_sr=feature_extractor.sampling_rate) + sr = feature_extractor.sampling_rate + + minimum_audio_length = feature_extractor.n_fft // 2 + 1 + if len(audio) < minimum_audio_length: + # Not enough audio; pad it. + audio = np.pad(audio, (0, minimum_audio_length - len(audio))) + + return MultiModalInputs({ + "audio_features": + feature_extractor(audio, + sampling_rate=sr, + padding="longest", + return_tensors="pt")["input_features"] + }) + + raise NotImplementedError(f"Unsupported data type: {type(data)}") + + +def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "audio" not in multi_modal_data: + return llm_inputs + + feature_extractor = whisper_feature_extractor(ctx) + audio_data, sample_rate = multi_modal_data["audio"] + + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. 
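+        # e.g. 8 kHz input resampled to the feature extractor's rate
+        # (16 kHz for Whisper) has twice as many samples, so the estimated
+        # length is scaled by the same ratio.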
+ adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - + (feature_extractor.hop_length - 1)) / feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, + repeat_count=audio_num_tokens, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class StackAudioFrames(nn.Module): + """ + Stack the audio embedding frames to reduce the sequence length by a factor + of `stack_factor`. + """ + + def __init__(self, stack_factor: int = 8): + super().__init__() + self.stack_factor = stack_factor + + def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: + B, T, C = audio_embeds.shape + T_pad = (T + self.stack_factor - + 1) // self.stack_factor * self.stack_factor + audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) + B, T, C = audio_embeds.shape + audio_embeds = audio_embeds.view(B, T // self.stack_factor, + C * self.stack_factor) + return audio_embeds + + +class FlippedSiluAndMul(SiluAndMul): + """Ultravox is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + + +class UltravoxProjector(nn.Module): + + def __init__(self, config: UltravoxConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self._pad_and_stack = StackAudioFrames(config.stack_factor) + dim = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim) + self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) + dim = self.hidden_dim + + if config.projector_act == "swiglu": + self.act = FlippedSiluAndMul() + dim = dim // 2 + else: + self.act = get_act_fn(config.projector_act) + + self.linear_2 = nn.Linear(dim, + config.text_config.hidden_size, + bias=False) + self.ln_post = RMSNorm(config.text_config.hidden_size) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + audio_features = self._pad_and_stack(audio_features) + audio_features = self.ln_pre(audio_features) + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.ln_post(hidden_states) + return hidden_states + + +class ModifiedWhisperEncoder(WhisperEncoder): + """ + Encoder portion of OpenAI's Whisper model. + + This implementation is a slightly modified version of HF Transformers' + Whisper Encoder, with only a few fixes: + 1. base_model_prefix updated to allow for doing `.from_pretrained` + directly on the encoder + 2. 
allow less than 30 second of audio padding to be passed in: + - relaxed ValueError check for `input_features` length to be less + than or equal to `expected_seq_length` instead of strictly equal + - embed_pos is now sliced to match the length of `inputs_embeds` + + Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py + See commentary: https://github.com/huggingface/transformers/issues/25744 + """ + + base_model_prefix = "model.encoder" + + def forward( + self, + input_features, + ): + expected_seq_length = (self.config.max_source_positions * + self.conv1.stride[0] * self.conv2.stride[0]) + if input_features.shape[-1] > expected_seq_length: + raise ValueError( + f"Whisper expects the mel input features to be of length " + f"{expected_seq_length} or less, but found " + f"{input_features.shape[-1]}. Make sure to pad the input mel " + f"features to {expected_seq_length}.") + + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + for encoder_layer in self.layers: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=None, + ) + + hidden_states = layer_outputs[0] + + hidden_states = self.layer_norm(hidden_states) + return hidden_states + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_ultravox_max_audio_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) +@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) +class UltravoxModel(nn.Module, SupportsMultiModal): + + def __init__(self, + config: UltravoxConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional["QuantizationConfig"] = None): + super().__init__() + self.config = config + self.multi_modal_config = multimodal_config + assert self.multi_modal_config + + if config.audio_model_id is not None: + self.audio_tower = ModifiedWhisperEncoder.from_pretrained( + config.audio_model_id) + else: + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + self.multi_modal_projector = UltravoxProjector(config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, quant_config) + + def _audio_features_to_embeddings( + self, input_features: torch.Tensor) -> torch.Tensor: + audio_input = input_features.to(self.audio_tower.dtype) + audio_features = self.audio_tower(audio_input) + audio_features = audio_features.to(self.audio_tower.dtype) + audio_embeddings = self.multi_modal_projector(audio_features) + return audio_embeddings + + def _parse_and_validate_audio_input( + self, **kwargs: object) -> Optional[UltravoxAudioInputs]: + audio_features = kwargs.pop("audio_features", None) + audio_embeds = kwargs.pop("audio_embeds", None) + + if audio_features is None and audio_embeds is None: + return None + + if audio_features is not None: + if not isinstance(audio_features, (torch.Tensor, list)): + raise ValueError("Incorrect type of audio features. 
" + f"Got type: {type(audio_features)}") + + return UltravoxAudioFeatureInputs(type="audio_features", + data=audio_features) + + if audio_embeds is not None: + if not isinstance(audio_embeds, torch.Tensor): + raise ValueError("Incorrect type of audio embeds. " + f"Got type: {type(audio_embeds)}") + + return UltravoxAudioEmbeddingInputs(type="audio_embeds", + data=audio_embeds) + + raise AssertionError("This line should be unreachable.") + + def _process_audio_input( + self, audio_input: UltravoxAudioInputs + ) -> Union[torch.Tensor, List[torch.Tensor]]: + if audio_input["type"] == "audio_embeds": + return audio_input["data"] + + audio_features = audio_input["data"] + if isinstance(audio_features, list): + # TODO: Batch these through the encoder/projector instead of + # serializing them. + return [ + self._audio_features_to_embeddings( + features.unsqueeze(0)).squeeze(0) + for features in audio_features + ] + else: + return self._audio_features_to_embeddings(audio_features) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[torch.Tensor], + **kwargs) -> SamplerOutput: + """Run forward pass for Ultravox + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted audio embeddings. The to-be-inserted + audio has a size that is essentially 6.25 tokens per second of audio. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_features: A batch of audio inputs, [1, 80, M]. + """ + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is not None: + audio_embeddings = self._process_audio_input(audio_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, audio_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + projector_weights, llm_weights = itertools.tee(weights, 2) + + # load projector weights + projector_weights = filter_weights(projector_weights, + "multi_modal_projector") + projector_params_dict = dict( + self.multi_modal_projector.named_parameters()) + for name, loaded_weight in projector_weights: + param = projector_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 916bd5e601bb7..6cdde949bc2b1 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,5 +1,4 @@ from functools 
import lru_cache -from typing import List, Optional, Tuple, TypeVar import torch from PIL import Image @@ -8,7 +7,6 @@ from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -16,87 +14,6 @@ logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) -cached_get_tokenizer = lru_cache(get_tokenizer) - -# Utilities for image input processors -_T = TypeVar("_T", str, int) - - -def repeat_and_pad_token( - token: _T, - *, - repeat_count: int = 1, - pad_token_left: Optional[_T] = None, - pad_token_right: Optional[_T] = None, -) -> List[_T]: - replacement = [token] * repeat_count - if pad_token_left is not None: - replacement = [pad_token_left] + replacement - if pad_token_right is not None: - replacement = replacement + [pad_token_right] - - return replacement - - -def repeat_and_pad_image_tokens( - tokenizer: AnyTokenizer, - prompt: Optional[str], - prompt_token_ids: List[int], - *, - image_token_id: int, - repeat_count: int = 1, - pad_token_left: Optional[int] = None, - pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int]]: - if prompt is None: - new_prompt = None - else: - image_token_str = tokenizer.decode(image_token_id) - pad_token_str_left = (None if pad_token_left is None else - tokenizer.decode(pad_token_left)) - pad_token_str_right = (None if pad_token_right is None else - tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - image_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) - - image_token_count = prompt.count(image_token_str) - # This is an arbitrary number to distinguish between the two cases - if image_token_count > 16: - logger.warning( - "Please follow the prompt format that is " - "documented on HuggingFace which does not involve " - "repeating %s tokens.", image_token_str) - elif image_token_count > 1: - logger.warning("Multiple image input is not supported yet, " - "so any extra image tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(image_token_str, replacement_str, 1) - - new_token_ids: List[int] = [] - for i, token in enumerate(prompt_token_ids): - if token == image_token_id: - replacement_ids = repeat_and_pad_token( - image_token_id, - repeat_count=repeat_count, - pad_token_left=pad_token_left, - pad_token_right=pad_token_right, - ) - new_token_ids.extend(replacement_ids) - - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break - else: - new_token_ids.append(token) - - return new_prompt, new_token_ids class ImagePlugin(MultiModalPlugin): diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d1e624cdb8ace..3bf430235462b 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,6 +1,7 @@ import base64 +from functools import lru_cache from io import BytesIO -from typing import Tuple, Union +from typing import List, Optional, Tuple, TypeVar, Union import librosa import numpy as np @@ -9,7 +10,13 @@ from vllm.connections import global_http_connection from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT +from vllm.logger import 
diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py
index d1e624cdb8ace..3bf430235462b 100644
--- a/vllm/multimodal/utils.py
+++ b/vllm/multimodal/utils.py
@@ -1,6 +1,7 @@
 import base64
+from functools import lru_cache
 from io import BytesIO
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, TypeVar, Union
 
 import librosa
 import numpy as np
@@ -9,7 +10,13 @@
 from vllm.connections import global_http_connection
 from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT
+from vllm.logger import init_logger
 from vllm.multimodal.base import MultiModalDataDict
+from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer
+
+logger = init_logger(__name__)
+
+cached_get_tokenizer = lru_cache(get_tokenizer)
 
 
 def _load_image_from_bytes(b: bytes):
@@ -154,3 +161,84 @@ def rescale_image_size(image: Image.Image,
     if transpose >= 0:
         image = image.transpose(Image.Transpose(transpose))
     return image
+
+
+# Utilities for input processors
+_T = TypeVar("_T", str, int)
+
+
+def repeat_and_pad_token(
+    token: _T,
+    *,
+    repeat_count: int = 1,
+    pad_token_left: Optional[_T] = None,
+    pad_token_right: Optional[_T] = None,
+) -> List[_T]:
+    replacement = [token] * repeat_count
+    if pad_token_left is not None:
+        replacement = [pad_token_left] + replacement
+    if pad_token_right is not None:
+        replacement = replacement + [pad_token_right]
+
+    return replacement
+
+
+def repeat_and_pad_placeholder_tokens(
+    tokenizer: AnyTokenizer,
+    prompt: Optional[str],
+    prompt_token_ids: List[int],
+    *,
+    placeholder_token_id: int,
+    repeat_count: int = 1,
+    pad_token_left: Optional[int] = None,
+    pad_token_right: Optional[int] = None,
+) -> Tuple[Optional[str], List[int]]:
+    if prompt is None:
+        new_prompt = None
+    else:
+        placeholder_token_str = tokenizer.decode(placeholder_token_id)
+        pad_token_str_left = (None if pad_token_left is None else
+                              tokenizer.decode(pad_token_left))
+        pad_token_str_right = (None if pad_token_right is None else
+                               tokenizer.decode(pad_token_right))
+        replacement_str = "".join(
+            repeat_and_pad_token(
+                placeholder_token_str,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_str_left,
+                pad_token_right=pad_token_str_right,
+            ))
+
+        placeholder_token_count = prompt.count(placeholder_token_str)
+        # This is an arbitrary number to distinguish between the two cases
+        if placeholder_token_count > 16:
+            logger.warning(
+                "Please follow the prompt format that is "
+                "documented on HuggingFace which does not involve "
+                "repeating %s tokens.", placeholder_token_str)
+        elif placeholder_token_count > 1:
+            logger.warning("Multiple multi-modal input is not supported yet, "
+                           "so any extra placeholder tokens will be treated "
+                           "as plain text.")
+
+        # The image tokens are removed to be consistent with HuggingFace
+        new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1)
+
+    new_token_ids: List[int] = []
+    for i, token in enumerate(prompt_token_ids):
+        if token == placeholder_token_id:
+            replacement_ids = repeat_and_pad_token(
+                placeholder_token_id,
+                repeat_count=repeat_count,
+                pad_token_left=pad_token_left,
+                pad_token_right=pad_token_right,
+            )
+            new_token_ids.extend(replacement_ids)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + 1:])
+            break
+        else:
+            new_token_ids.append(token)
+
+    return new_prompt, new_token_ids
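The helpers above are easiest to see through a small example. The sketch below assumes this patch is applied so the functions can be imported from vllm.multimodal.utils; the tokenizer stub, prompt, and token ids are made up for illustration.

from vllm.multimodal.utils import (repeat_and_pad_placeholder_tokens,
                                   repeat_and_pad_token)

# One placeholder token repeated 4 times, padded on both sides.
print(repeat_and_pad_token(32000, repeat_count=4,
                           pad_token_left=1, pad_token_right=2))
# -> [1, 32000, 32000, 32000, 32000, 2]


class _StubTokenizer:
    # Minimal stand-in for a HF tokenizer: decode() maps an id to a string.
    def decode(self, token_id: int) -> str:
        return {32000: "<|audio|>"}.get(token_id, f"<{token_id}>")


prompt = "Listen: <|audio|>\nWhat is said?"
prompt_ids = [9906, 32000, 3923]
new_prompt, new_ids = repeat_and_pad_placeholder_tokens(
    _StubTokenizer(),
    prompt,
    prompt_ids,
    placeholder_token_id=32000,
    repeat_count=3)
# The single placeholder is expanded to three copies in both views.
print(new_ids)  # -> [9906, 32000, 32000, 32000, 3923]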
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 5f04b39ef524e..d3024965c0b4c 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -12,7 +12,7 @@
                                              InternVLChatConfig, JAISConfig,
                                              MedusaConfig, MLPSpeculatorConfig,
                                              MPTConfig, NemotronConfig,
-                                             RWConfig)
+                                             RWConfig, UltravoxConfig)
 
 if VLLM_USE_MODELSCOPE:
     from modelscope import AutoConfig
@@ -32,6 +32,7 @@
     "medusa": MedusaConfig,
     "internvl_chat": InternVLChatConfig,
     "nemotron": NemotronConfig,
+    "ultravox": UltravoxConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py
index 5ccacd4a4c40a..22b906a3149ec 100644
--- a/vllm/transformers_utils/configs/__init__.py
+++ b/vllm/transformers_utils/configs/__init__.py
@@ -10,6 +10,7 @@
 from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.nemotron import NemotronConfig
+from vllm.transformers_utils.configs.ultravox import UltravoxConfig
 
 __all__ = [
     "ChatGLMConfig",
@@ -21,4 +22,5 @@
     "MedusaConfig",
     "MLPSpeculatorConfig",
     "NemotronConfig",
+    "UltravoxConfig",
 ]
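The registry entry added above is what lets vLLM map the HuggingFace "model_type" string to its own config class. A simplified sketch of that lookup pattern follows; vLLM's real get_config() handles more cases (ModelScope, trust-remote-code, and so on), so this is only an illustration.

from transformers import AutoConfig

from vllm.transformers_utils.configs import UltravoxConfig

_CONFIG_REGISTRY = {
    "ultravox": UltravoxConfig,
}


def resolve_config_cls(model_type: str):
    # Use the vLLM-specific config class when one is registered, otherwise
    # fall back to HuggingFace's AutoConfig.
    return _CONFIG_REGISTRY.get(model_type, AutoConfig)


print(resolve_config_cls("ultravox").__name__)  # UltravoxConfig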
+ """ + + model_type = "ultravox" + is_composition = False + + def __init__( + self, + audio_config: Optional[Dict[str, Any]] = None, + text_config: Optional[Dict[str, Any]] = None, + audio_model_id: Optional[str] = None, + text_model_id: Optional[str] = None, + ignore_index: int = -100, + audio_token_index: int = 32000, + hidden_size: int = 4096, + stack_factor: int = 8, + norm_init: float = 0.4, + projector_act: str = "swiglu", + text_model_lora_config: Optional[Dict[str, Any]] = None, + audio_model_lora_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.ignore_index = ignore_index + + self.audio_model_id = audio_model_id + self.text_model_id = text_model_id + self.audio_token_index = audio_token_index + + self.hidden_size = hidden_size + self.stack_factor = stack_factor + self.norm_init = norm_init + self.projector_act = projector_act + + if text_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.text_config = get_config(text_model_id, + trust_remote_code=False) + else: + text_config = text_config or {} + self.text_config = transformers.CONFIG_MAPPING[text_config.get( + "model_type", "llama")](**text_config) + + if audio_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(audio_model_id, + trust_remote_code=False) + else: + audio_config = audio_config or {} + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( + "model_type", "whisper")](**audio_config) + + self.text_model_lora_config = text_model_lora_config or {} + self.audio_model_lora_config = audio_model_lora_config or {} + + self.vocab_size = self.text_config.vocab_size + + self.initializer_range = self.text_config.initializer_range + + super().__init__(**kwargs)