From 3054191873d93668ccdb101216188d7a18499700 Mon Sep 17 00:00:00 2001 From: Peter Salas Date: Fri, 16 Aug 2024 22:37:06 +0000 Subject: [PATCH] Add UltravoxModel and UltravoxConfig --- tests/entrypoints/openai/test_audio.py | 141 +------ vllm/entrypoints/chat_utils.py | 6 +- vllm/model_executor/models/__init__.py | 3 +- vllm/model_executor/models/blip.py | 8 +- vllm/model_executor/models/chameleon.py | 8 +- vllm/model_executor/models/clip.py | 8 +- vllm/model_executor/models/fuyu.py | 4 +- vllm/model_executor/models/internvl.py | 2 +- vllm/model_executor/models/paligemma.py | 2 +- vllm/model_executor/models/phi3v.py | 2 +- vllm/model_executor/models/siglip.py | 8 +- vllm/model_executor/models/ultravox.py | 438 ++++++++++++++++++++ vllm/multimodal/image.py | 83 ---- vllm/multimodal/utils.py | 92 +++- vllm/transformers_utils/config.py | 3 +- vllm/transformers_utils/configs/__init__.py | 2 + vllm/transformers_utils/configs/ultravox.py | 99 +++++ vllm/transformers_utils/processor.py | 33 ++ 18 files changed, 713 insertions(+), 229 deletions(-) create mode 100644 vllm/model_executor/models/ultravox.py create mode 100644 vllm/transformers_utils/configs/ultravox.py create mode 100644 vllm/transformers_utils/processor.py diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 3c2c652fd317d..02148edc10162 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -1,134 +1,35 @@ -import math -import sys -import time -from typing import Dict, List, Optional, Tuple, Union, cast -from unittest.mock import patch - -import librosa -import numpy as np +from typing import Dict, List + import openai import pytest -import requests -import torch - -from vllm import ModelRegistry -from vllm.config import MultiModalConfig -from vllm.inputs import INPUT_REGISTRY -from vllm.inputs.data import LLMInputs -from vllm.inputs.registry import InputContext -from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) -from vllm.multimodal.utils import encode_audio_base64, fetch_audio -from vllm.utils import get_open_port -from ...utils import VLLM_PATH +from vllm.multimodal.utils import encode_audio_base64, fetch_audio -chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" -assert chatml_jinja_path.exists() +from ...utils import RemoteOpenAIServer -MODEL_NAME = "facebook/opt-125m" +MODEL_NAME = "fixie-ai/ultravox-v0_3" TEST_AUDIO_URLS = [ "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg", ] -def server_function(port): - - def fake_input_mapper(ctx: InputContext, data: object): - assert isinstance(data, tuple) - (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) - - # Resample it to 1 sample per second - audio = librosa.resample(audio, orig_sr=sr, target_sr=1) - return MultiModalInputs({"processed_audio": torch.from_numpy(audio)}) - - def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs): - multi_modal_data = llm_inputs.get("multi_modal_data") - if multi_modal_data is None or "audio" not in multi_modal_data: - return llm_inputs - - audio, sr = multi_modal_data.get("audio") - audio_duration = math.ceil(len(audio) / sr) - - new_prompt, new_token_ids = repeat_and_pad_image_tokens( - 
cached_get_tokenizer(ctx.model_config.tokenizer), - llm_inputs.get("prompt"), - llm_inputs["prompt_token_ids"], - image_token_id=62, # "_" - repeat_count=audio_duration) - - return LLMInputs(prompt_token_ids=new_token_ids, - prompt=new_prompt, - multi_modal_data=multi_modal_data) - - @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper) - @MULTIMODAL_REGISTRY.register_max_multimodal_tokens( - "audio", lambda *_, **__: 100) - @INPUT_REGISTRY.register_input_processor(fake_input_processor) - class FakeAudioModel(OPTForCausalLM, SupportsMultiModal): - - def __init__(self, *args, multimodal_config: MultiModalConfig, - **kwargs): - assert multimodal_config is not None - super().__init__(*args, **kwargs) - - def forward( - self, - *args, - processed_audio: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - return super().forward(*args, **kwargs) - - ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel) - - with patch("vllm.entrypoints.chat_utils._mm_token_str", - lambda *_, **__: "_"): - sys.argv = ["placeholder.py"] + \ - (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 " - "--dtype bfloat16 --enforce-eager --api-key token-abc123 " - f"--port {port} --chat-template {chatml_jinja_path} " - "--disable-frontend-multiprocessing").split() - import runpy - runpy.run_module('vllm.entrypoints.openai.api_server', - run_name='__main__') +@pytest.fixture(scope="module") +def server(): + args = [ + "--dtype", + "bfloat16", + "--max-model-len", + "4096", + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") -def client(): - port = get_open_port() - ctx = torch.multiprocessing.get_context("spawn") - server = ctx.Process(target=server_function, args=(port, )) - server.start() - MAX_SERVER_START_WAIT_S = 60 - client = openai.AsyncOpenAI( - base_url=f"http://localhost:{port}/v1", - api_key="token-abc123", - ) - # run health check - health_url = f"http://localhost:{port}/health" - start = time.time() - while True: - try: - if requests.get(health_url).status_code == 200: - break - except Exception as err: - result = server.exitcode - if result is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError("Server failed to start in time.") from err - - try: - yield client - finally: - server.kill() +def client(server): + return server.get_async_client() @pytest.fixture(scope="session") @@ -172,7 +73,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=36, total_tokens=46) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message @@ -227,7 +128,7 @@ async def test_single_chat_session_audio_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=36, total_tokens=46) + completion_tokens=10, prompt_tokens=202, total_tokens=212) message = choice.message message = chat_completion.choices[0].message diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 4a0b0f879e8ef..a03850a654bf2 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py 
@@ -118,8 +118,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer, modality: Literal["image", "audio"]) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) + model_type = model_config.hf_config.model_type if modality == "image": - model_type = model_config.hf_config.model_type if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return "<|image_1|>" @@ -135,7 +135,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer, raise TypeError(f"Unknown model type: {model_type}") elif modality == "audio": - raise TypeError("No audio models are supported yet.") + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") else: raise TypeError(f"Unknown modality: {modality}") diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 46aa62e24e8af..ee6762dba0cf1 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -76,7 +76,8 @@ "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM") + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "UltravoxModel": ("ultravox", "UltravoxModel"), } _EMBEDDING_MODELS = { diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index a6fd5f58b3cb6..a5eea7c5617cd 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -14,8 +14,8 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -94,11 +94,11 @@ def input_processor_for_blip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 6776b93d126b0..4adb97d837af1 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -29,8 +29,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from vllm.utils import print_warning_once @@ -120,11 +120,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs): model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer) - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, 
llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=CHAMELEON_IMAGE_TOKEN_ID, + placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID, repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH, pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID, pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID, diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index fcd360ce8fd72..193e8f7958f84 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -15,8 +15,8 @@ RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -100,11 +100,11 @@ def input_processor_for_clip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index e8184e466c5bf..1227fe1a50163 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -35,8 +35,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import (cached_get_image_processor, - cached_get_tokenizer) +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData from .interfaces import SupportsMultiModal diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index b379c86c1912b..c996f0b73f293 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -23,7 +23,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8beb2778fe37a..8cb5065ed79ec 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -16,7 +16,7 @@ from vllm.model_executor.models.gemma import GemmaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .interfaces import SupportsMultiModal diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1c8bb8a837c86..4b99b968514a6 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -37,7 +37,7 @@ from 
vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.image import cached_get_tokenizer +from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SamplerOutput from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 4df8c0b54201c..1c2ff5a637fcc 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -23,8 +23,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.multimodal.image import (cached_get_tokenizer, - repeat_and_pad_image_tokens) +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) from vllm.sequence import SequenceData @@ -109,11 +109,11 @@ def input_processor_for_siglip( else: image_feature_size = image_feature_size_override - new_prompt, new_token_ids = repeat_and_pad_image_tokens( + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( tokenizer, llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], - image_token_id=image_token_id, + placeholder_token_id=image_token_id, repeat_count=image_feature_size, ) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py new file mode 100644 index 0000000000000..8878bec910f30 --- /dev/null +++ b/vllm/model_executor/models/ultravox.py @@ -0,0 +1,438 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py +"""PyTorch Ultravox model.""" + +import itertools +import math +from typing import Iterable, List, Mapping, Optional, Tuple, Union, cast + +import librosa +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from transformers.modeling_outputs import BaseModelOutput +from transformers.models.whisper import WhisperFeatureExtractor +from transformers.models.whisper.modeling_whisper import WhisperEncoder + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, MultiModalConfig +from vllm.inputs import INPUT_REGISTRY +from vllm.inputs.data import LLMInputs +from vllm.inputs.registry import InputContext +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.utils import (filter_weights, + init_vllm_registered_model, + merge_multimodal_embeddings) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import BatchedTensors, MultiModalInputs +from vllm.multimodal.utils import (cached_get_processor, cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import SamplerOutput, SequenceData +from vllm.transformers_utils.configs.ultravox import UltravoxConfig + +_AUDIO_PLACEHOLDER_TOKEN = 128002 +_AUDIO_TOKENS_PER_SECOND = 6.25 + +logger = 
init_logger(__name__) + + +def whisper_feature_extractor(ctx: InputContext) -> WhisperFeatureExtractor: + whisper_processor = cached_get_processor( + ctx.get_hf_config(UltravoxConfig).audio_model_id) + return whisper_processor.feature_extractor + + +def get_ultravox_max_audio_tokens(ctx: InputContext): + feature_extractor = whisper_feature_extractor(ctx) + return math.ceil(feature_extractor.chunk_length * _AUDIO_TOKENS_PER_SECOND) + + +def dummy_data_for_ultravox( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +): + feature_extractor = whisper_feature_extractor(ctx) + + audio_count = mm_counts["audio"] + + audio_token_ids = [_AUDIO_PLACEHOLDER_TOKEN + ] * get_ultravox_max_audio_tokens(ctx) * audio_count + other_token_ids = [0] * (seq_len - len(audio_token_ids)) + + audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) + mm_dict = { + "audio": + audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count + } + + return (SequenceData(audio_token_ids + other_token_ids), mm_dict) + + +def input_mapper_for_ultravox(ctx: InputContext, data: object): + if isinstance(data, tuple): + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + feature_extractor = whisper_feature_extractor(ctx) + + if sr != feature_extractor.sampling_rate: + audio = librosa.resample(audio, + orig_sr=sr, + target_sr=feature_extractor.sampling_rate) + sr = feature_extractor.sampling_rate + + minimum_audio_length = feature_extractor.n_fft // 2 + 1 + if len(audio) < minimum_audio_length: + # Not enough audio; pad it. + audio = np.pad(audio, (0, minimum_audio_length - len(audio))) + + return MultiModalInputs( + feature_extractor(audio, + sampling_rate=sr, + padding="longest", + return_tensors="pt")) + + raise NotImplementedError(f"Unsupported data type: {type(data)}") + + +def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + if multi_modal_data is None or "audio" not in multi_modal_data: + return llm_inputs + + feature_extractor = whisper_feature_extractor(ctx) + audio_data, sample_rate = multi_modal_data["audio"] + + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. + adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - + (feature_extractor.hop_length - 1)) / feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) + + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer, + llm_inputs.get("prompt"), + llm_inputs["prompt_token_ids"], + placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, + repeat_count=audio_num_tokens, + ) + + # NOTE: Create a defensive copy of the original inputs + return LLMInputs(prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data=multi_modal_data) + + +class StackAudioFrames(nn.Module): + """ + Stack the audio embedding frames to reduce the sequence length by a factor + of `stack_factor`. 
+ """ + + def __init__(self, stack_factor: int = 8): + super().__init__() + self.stack_factor = stack_factor + + def forward(self, audio_embeds: torch.Tensor) -> torch.Tensor: + B, T, C = audio_embeds.shape + T_pad = (T + self.stack_factor - + 1) // self.stack_factor * self.stack_factor + audio_embeds = F.pad(audio_embeds, (0, 0, 0, T_pad - T)) + B, T, C = audio_embeds.shape + audio_embeds = audio_embeds.view(B, T // self.stack_factor, + C * self.stack_factor) + return audio_embeds + + +class SwiGLU(nn.Module): + + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + +class UltravoxProjector(nn.Sequential): + + def __init__(self, config: UltravoxConfig): + super().__init__() + self.hidden_dim = config.hidden_size + self._pad_and_stack = StackAudioFrames(config.stack_factor) + dim = config.audio_config.hidden_size * config.stack_factor + self.ln_pre = RMSNorm(dim) + self.linear_1 = nn.Linear(dim, self.hidden_dim, bias=False) + dim = self.hidden_dim + self.act = SwiGLU( + ) if config.projector_act == "swiglu" else get_act_fn( + config.projector_act) + dim = dim // 2 if config.projector_act == "swiglu" else dim + self.linear_2 = nn.Linear(dim, + config.text_config.hidden_size, + bias=False) + self.ln_post = RMSNorm(config.text_config.hidden_size) + + def forward(self, audio_features: torch.Tensor) -> torch.Tensor: + audio_features = self._pad_and_stack(audio_features) + audio_features = self.ln_pre(audio_features) + hidden_states = self.linear_1(audio_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + hidden_states = self.ln_post(hidden_states) + return hidden_states + + +class ModifiedWhisperEncoder(WhisperEncoder): + """ + Encoder portion of OpenAI's Whisper model. + + This implementation is a slightly modified version of HF Transformers' + Whisper Encoder, with only a few fixes: + 1. base_model_prefix updated to allow for doing `.from_pretrained` + directly on the encoder + 2. allow less than 30 second of audio padding to be passed in: + - relaxed ValueError check for `input_features` length to be less + than or equal to `expected_seq_length` instead of strictly equal + - embed_pos is now sliced to match the length of `inputs_embeds` + + Original: https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py + See commentary: https://github.com/huggingface/transformers/issues/25744 + """ + + base_model_prefix = "model.encoder" + + def forward( + self, + input_features, + attention_mask=None, + head_mask=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + expected_seq_length = (self.config.max_source_positions * + self.conv1.stride[0] * self.conv2.stride[0]) + if input_features.shape[-1] > expected_seq_length: + raise ValueError( + f"Whisper expects the mel input features to be of length " + f"{expected_seq_length} or less, but found " + f"{input_features.shape[-1]}. 
Make sure to pad the input mel " + f"features to {expected_seq_length}.") + + output_attentions = (output_attentions if output_attentions is not None + else self.config.output_attentions) + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else + self.config.output_hidden_states) + return_dict = (return_dict if return_dict is not None else + self.config.use_return_dict) + inputs_embeds = nn.functional.gelu(self.conv1(input_features)) + inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) + + inputs_embeds = inputs_embeds.permute(0, 2, 1) + embed_pos = self.embed_positions.weight[:inputs_embeds.size(-2)] + + hidden_states = inputs_embeds + embed_pos + hidden_states = nn.functional.dropout(hidden_states, + p=self.dropout, + training=self.training) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + assert head_mask.size()[0] == (len(self.layers)), ( + f"The head_mask should be specified for {len(self.layers)} " + f"layers, but it is for {head_mask.size()[0]}.") + + for idx, encoder_layer in enumerate(self.layers): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) # noqa: E501 + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + None, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + None, + layer_head_mask=(head_mask[idx] + if head_mask is not None else None), + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1], ) + + hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + encoder_states = encoder_states + (hidden_states, ) + + if not return_dict: + return tuple( + v for v in [hidden_states, encoder_states, all_attentions] + if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, + ) + + +@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_ultravox) +@MULTIMODAL_REGISTRY.register_max_multimodal_tokens( + "audio", get_ultravox_max_audio_tokens) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) +@INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) +class UltravoxModel(nn.Module, SupportsMultiModal): + + def __init__(self, + config: UltravoxConfig, + multimodal_config: MultiModalConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional["QuantizationConfig"] = None): + super().__init__() + self.config = config + self.multi_modal_config = multimodal_config + assert self.multi_modal_config + + if config.audio_model_id is not None: + self.audio_tower = ModifiedWhisperEncoder.from_pretrained( + config.audio_model_id) + else: + self.audio_tower = ModifiedWhisperEncoder(config.audio_config) + self.multi_modal_projector = UltravoxProjector(config) + self.language_model = init_vllm_registered_model( + config.text_config, cache_config, 
quant_config) + + def _audio_features_to_embeddings(self, input_features: torch.Tensor, + dtype: torch.dtype) -> torch.Tensor: + audio_input = input_features.to(self.audio_tower.dtype) + audio_features = self.audio_tower(audio_input).last_hidden_state + audio_features = audio_features.to(self.audio_tower.dtype) + audio_embeddings = self.multi_modal_projector(audio_features).to(dtype) + return audio_embeddings + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[torch.Tensor], + *, + input_features: Optional[BatchedTensors] = None, + ) -> SamplerOutput: + """Run forward pass for Ultravox + + One key thing to understand is the `input_ids` already accounts for the + positions of the to-be-inserted audio embeddings. The to-be-inserted + audio has a size that is essentially 6.25 tokens per second of audio. + + This way, the `positions` and `attn_metadata` are consistent + with the `input_ids`. + + Args: + input_features: A batch of audio inputs, [1, 80, M]. + """ + if input_features is not None: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + if isinstance(input_features, list): + # TODO: Batch these through the encoder/projector instead of + # serializing them. + audio_embeddings = [ + self._audio_features_to_embeddings( + single_features.unsqueeze(0), + inputs_embeds.dtype).squeeze(0) + for single_features in input_features + ] + elif isinstance(input_features, torch.Tensor): + audio_embeddings = self._audio_features_to_embeddings( + input_features, inputs_embeds.dtype) + else: + raise ValueError( + "The input audio features should be a tensor or a list " + f"of tensors, not {type(input_features)}") + + merge_multimodal_embeddings(input_ids, inputs_embeds, + audio_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + return self.language_model.compute_logits(hidden_states, + sampling_metadata) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + return self.language_model.sample(logits, sampling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + projector_weights, llm_weights = itertools.tee(weights, 2) + + # load projector weights + projector_weights = filter_weights(projector_weights, + "multi_modal_projector") + projector_params_dict = dict( + self.multi_modal_projector.named_parameters()) + for name, loaded_weight in projector_weights: + param = projector_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index 916bd5e601bb7..6cdde949bc2b1 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -1,5 +1,4 @@ from functools import lru_cache -from typing import List, Optional, Tuple, TypeVar import torch from 
PIL import Image @@ -8,7 +7,6 @@ from vllm.inputs.registry import InputContext from vllm.logger import init_logger from vllm.transformers_utils.image_processor import get_image_processor -from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer from vllm.utils import is_list_of from .base import MultiModalData, MultiModalInputs, MultiModalPlugin @@ -16,87 +14,6 @@ logger = init_logger(__name__) cached_get_image_processor = lru_cache(get_image_processor) -cached_get_tokenizer = lru_cache(get_tokenizer) - -# Utilities for image input processors -_T = TypeVar("_T", str, int) - - -def repeat_and_pad_token( - token: _T, - *, - repeat_count: int = 1, - pad_token_left: Optional[_T] = None, - pad_token_right: Optional[_T] = None, -) -> List[_T]: - replacement = [token] * repeat_count - if pad_token_left is not None: - replacement = [pad_token_left] + replacement - if pad_token_right is not None: - replacement = replacement + [pad_token_right] - - return replacement - - -def repeat_and_pad_image_tokens( - tokenizer: AnyTokenizer, - prompt: Optional[str], - prompt_token_ids: List[int], - *, - image_token_id: int, - repeat_count: int = 1, - pad_token_left: Optional[int] = None, - pad_token_right: Optional[int] = None, -) -> Tuple[Optional[str], List[int]]: - if prompt is None: - new_prompt = None - else: - image_token_str = tokenizer.decode(image_token_id) - pad_token_str_left = (None if pad_token_left is None else - tokenizer.decode(pad_token_left)) - pad_token_str_right = (None if pad_token_right is None else - tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - image_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) - - image_token_count = prompt.count(image_token_str) - # This is an arbitrary number to distinguish between the two cases - if image_token_count > 16: - logger.warning( - "Please follow the prompt format that is " - "documented on HuggingFace which does not involve " - "repeating %s tokens.", image_token_str) - elif image_token_count > 1: - logger.warning("Multiple image input is not supported yet, " - "so any extra image tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(image_token_str, replacement_str, 1) - - new_token_ids: List[int] = [] - for i, token in enumerate(prompt_token_ids): - if token == image_token_id: - replacement_ids = repeat_and_pad_token( - image_token_id, - repeat_count=repeat_count, - pad_token_left=pad_token_left, - pad_token_right=pad_token_right, - ) - new_token_ids.extend(replacement_ids) - - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break - else: - new_token_ids.append(token) - - return new_prompt, new_token_ids class ImagePlugin(MultiModalPlugin): diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index d1e624cdb8ace..c1317e8a671e7 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,6 +1,7 @@ import base64 +from functools import lru_cache from io import BytesIO -from typing import Tuple, Union +from typing import List, Optional, Tuple, TypeVar, Union import librosa import numpy as np @@ -9,7 +10,15 @@ from vllm.connections import global_http_connection from vllm.envs import VLLM_AUDIO_FETCH_TIMEOUT, VLLM_IMAGE_FETCH_TIMEOUT +from vllm.logger import init_logger from vllm.multimodal.base import MultiModalDataDict +from 
vllm.transformers_utils.processor import get_processor +from vllm.transformers_utils.tokenizer import AnyTokenizer, get_tokenizer + +logger = init_logger(__name__) + +cached_get_tokenizer = lru_cache(get_tokenizer) +cached_get_processor = lru_cache(get_processor) def _load_image_from_bytes(b: bytes): @@ -154,3 +163,84 @@ def rescale_image_size(image: Image.Image, if transpose >= 0: image = image.transpose(Image.Transpose(transpose)) return image + + +# Utilities for input processors +_T = TypeVar("_T", str, int) + + +def repeat_and_pad_token( + token: _T, + *, + repeat_count: int = 1, + pad_token_left: Optional[_T] = None, + pad_token_right: Optional[_T] = None, +) -> List[_T]: + replacement = [token] * repeat_count + if pad_token_left is not None: + replacement = [pad_token_left] + replacement + if pad_token_right is not None: + replacement = replacement + [pad_token_right] + + return replacement + + +def repeat_and_pad_placeholder_tokens( + tokenizer: AnyTokenizer, + prompt: Optional[str], + prompt_token_ids: List[int], + *, + placeholder_token_id: int, + repeat_count: int = 1, + pad_token_left: Optional[int] = None, + pad_token_right: Optional[int] = None, +) -> Tuple[Optional[str], List[int]]: + if prompt is None: + new_prompt = None + else: + placeholder_token_str = tokenizer.decode(placeholder_token_id) + pad_token_str_left = (None if pad_token_left is None else + tokenizer.decode(pad_token_left)) + pad_token_str_right = (None if pad_token_right is None else + tokenizer.decode(pad_token_right)) + replacement_str = "".join( + repeat_and_pad_token( + placeholder_token_str, + repeat_count=repeat_count, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + + placeholder_token_count = prompt.count(placeholder_token_str) + # This is an arbitrary number to distinguish between the two cases + if placeholder_token_count > 16: + logger.warning( + "Please follow the prompt format that is " + "documented on HuggingFace which does not involve " + "repeating %s tokens.", placeholder_token_str) + elif placeholder_token_count > 1: + logger.warning("Multiple multi-modal input is not supported yet, " + "so any extra placeholder tokens will be treated " + "as plain text.") + + # The image tokens are removed to be consistent with HuggingFace + new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1) + + new_token_ids: List[int] = [] + for i, token in enumerate(prompt_token_ids): + if token == placeholder_token_id: + replacement_ids = repeat_and_pad_token( + placeholder_token_id, + repeat_count=repeat_count, + pad_token_left=pad_token_left, + pad_token_right=pad_token_right, + ) + new_token_ids.extend(replacement_ids) + + # No need to further scan the list since we only replace once + new_token_ids.extend(prompt_token_ids[i + 1:]) + break + else: + new_token_ids.append(token) + + return new_prompt, new_token_ids diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 5f04b39ef524e..d3024965c0b4c 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -12,7 +12,7 @@ InternVLChatConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, - RWConfig) + RWConfig, UltravoxConfig) if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -32,6 +32,7 @@ "medusa": MedusaConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, + "ultravox": UltravoxConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/vllm/transformers_utils/configs/__init__.py 
b/vllm/transformers_utils/configs/__init__.py index 5ccacd4a4c40a..22b906a3149ec 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -10,6 +10,7 @@ from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig from vllm.transformers_utils.configs.mpt import MPTConfig from vllm.transformers_utils.configs.nemotron import NemotronConfig +from vllm.transformers_utils.configs.ultravox import UltravoxConfig __all__ = [ "ChatGLMConfig", @@ -21,4 +22,5 @@ "MedusaConfig", "MLPSpeculatorConfig", "NemotronConfig", + "UltravoxConfig", ] diff --git a/vllm/transformers_utils/configs/ultravox.py b/vllm/transformers_utils/configs/ultravox.py new file mode 100644 index 0000000000000..f724bf7f2f1cd --- /dev/null +++ b/vllm/transformers_utils/configs/ultravox.py @@ -0,0 +1,99 @@ +# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_config.py +from typing import Any, Dict, Optional + +import transformers + + +class UltravoxConfig(transformers.PretrainedConfig): + r""" + This is the configuration class to store the configuration of a + [`UltravoxForConditionalGeneration`]. It is used to instantiate an + Ultravox model according to the specified arguments, defining the model + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + Args: + audio_config (`Union[AutoConfig, dict]`, *optional*): + Custom audio config or dict + text_config (`Union[AutoConfig, dict]`, *optional*): + The config object of the text backbone. Can be any of `LlamaConfig` + or `MistralConfig`. + ignore_index (`int`, *optional*, defaults to -100): + The ignore index for the loss function. + audio_token_index (`int`, *optional*, defaults to 32000): + The audio token index to encode the audio prompt. + stack_factor (`int`, *optional*, defaults to 8): + Audio downsampling factor for the multimodal projector. + norm_init (`float`, *optional*, defaults to 0.4): + The initialization value for the layer normalization. + projector_act (`str`, *optional*, defaults to `"swiglu"`): + The activation function used by the multimodal projector. + text_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the text model. + audio_model_lora_config (`LoraConfigSimplified`, *optional*): + The LoRA configuration for finetuning the audio model. 
+ """ + + model_type = "ultravox" + is_composition = False + + def __init__( + self, + audio_config: Optional[Dict[str, Any]] = None, + text_config: Optional[Dict[str, Any]] = None, + audio_model_id: Optional[str] = None, + text_model_id: Optional[str] = None, + ignore_index: int = -100, + audio_token_index: int = 32000, + hidden_size: int = 4096, + stack_factor: int = 8, + norm_init: float = 0.4, + projector_act: str = "swiglu", + text_model_lora_config: Optional[Dict[str, Any]] = None, + audio_model_lora_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + self.ignore_index = ignore_index + + self.audio_model_id = audio_model_id + self.text_model_id = text_model_id + self.audio_token_index = audio_token_index + + self.hidden_size = hidden_size + self.stack_factor = stack_factor + self.norm_init = norm_init + self.projector_act = projector_act + + if text_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.text_config = get_config(text_model_id, + trust_remote_code=False) + else: + text_config = text_config or {} + self.text_config = transformers.CONFIG_MAPPING[text_config.get( + "model_type", "llama")](**text_config) + + if audio_model_id is not None: + # Avoid circular import + from vllm.transformers_utils.config import get_config + + self.audio_config = get_config(audio_model_id, + trust_remote_code=False) + else: + audio_config = audio_config or {} + self.audio_config = transformers.CONFIG_MAPPING[audio_config.get( + "model_type", "whisper")](**audio_config) + + self.text_model_lora_config = text_model_lora_config or {} + self.audio_model_lora_config = audio_model_lora_config or {} + + self.vocab_size = self.text_config.vocab_size + + self.initializer_range = self.text_config.initializer_range + + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py new file mode 100644 index 0000000000000..7b7778e0eb876 --- /dev/null +++ b/vllm/transformers_utils/processor.py @@ -0,0 +1,33 @@ +def get_processor( + processor_name: str, + *args, + trust_remote_code: bool = False, + **kwargs, +): + """Gets a processor for the given model name via HuggingFace.""" + # don't put this import at the top level + # it will call torch.cuda.device_count() + from transformers import AutoProcessor + + try: + processor = AutoProcessor.from_pretrained( + processor_name, + *args, + trust_remote_code=trust_remote_code, + **kwargs) + except ValueError as e: + # If the error pertains to the processor class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + # Unlike AutoTokenizer, AutoProcessor does not separate such errors + if not trust_remote_code: + err_msg = ( + "Failed to load the processor. If the processor is " + "a custom processor not yet available in the HuggingFace " + "transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + return processor
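
Usage sketch (not part of the patch): the snippet below mirrors the flow of the updated
tests/entrypoints/openai/test_audio.py against a locally running OpenAI-compatible server.
Only fetch_audio, encode_audio_base64, and the "audio_url" content part come from the code
above; the serve command, port, API key, and prompt text are illustrative assumptions.

import asyncio

import openai

from vllm.multimodal.utils import encode_audio_base64, fetch_audio

AUDIO_URL = ("https://upload.wikimedia.org/wikipedia/en/b/bf/"
             "Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg")


async def main() -> None:
    # Assumes a server started roughly the way the test fixture does, e.g.:
    #   vllm serve fixie-ai/ultravox-v0_3 --dtype bfloat16 \
    #       --max-model-len 4096 --enforce-eager
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")

    # Pass the audio by URL, as test_single_chat_session_audio does...
    audio_part = {"type": "audio_url", "audio_url": {"url": AUDIO_URL}}

    # ...or inline it as a base64 data URL, as
    # test_single_chat_session_audio_base64encoded does. Either part can be
    # used in the message below; only one audio input per request.
    audio, sr = fetch_audio(AUDIO_URL)
    audio_part_b64 = {
        "type": "audio_url",
        "audio_url": {
            "url": f"data:audio/ogg;base64,{encode_audio_base64(audio, sr)}"
        },
    }

    chat_completion = await client.chat.completions.create(
        model="fixie-ai/ultravox-v0_3",
        messages=[{
            "role": "user",
            "content": [
                audio_part,
                {"type": "text", "text": "What is happening in this audio?"},
            ],
        }],
        max_tokens=10,
    )
    print(chat_completion.choices[0].message.content)


asyncio.run(main())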