Add UltravoxModel and UltravoxConfig
petersalas committed Aug 16, 2024
1 parent 37fd47e commit 3054191
Showing 18 changed files with 713 additions and 229 deletions.
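
This commit adds Ultravox, an audio-capable multimodal model, to vLLM and teaches the OpenAI-compatible chat frontend to accept audio inputs. As a rough sketch of what the change enables (the exact request payload is not shown in this diff; the "audio_url" content part, port, and API key below are assumptions modeled on the rewritten test):

    # Hedged sketch: querying a vLLM OpenAI-compatible server that is serving
    # fixie-ai/ultravox-v0_3. The "audio_url" content part, port, and API key
    # are assumptions based on the tests touched by this commit, not a
    # documented API.
    import asyncio

    import openai

    AUDIO_URL = ("https://upload.wikimedia.org/wikipedia/en/b/bf/"
                 "Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg")


    async def main():
        client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                    api_key="token-abc123")
        chat_completion = await client.chat.completions.create(
            model="fixie-ai/ultravox-v0_3",
            messages=[{
                "role": "user",
                "content": [
                    {"type": "audio_url", "audio_url": {"url": AUDIO_URL}},
                    {"type": "text", "text": "What is happening in this clip?"},
                ],
            }],
            max_tokens=10,
        )
        print(chat_completion.choices[0].message.content)


    asyncio.run(main())
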
141 changes: 21 additions & 120 deletions tests/entrypoints/openai/test_audio.py
@@ -1,134 +1,35 @@
-import math
-import sys
-import time
-from typing import Dict, List, Optional, Tuple, Union, cast
-from unittest.mock import patch
-
-import librosa
-import numpy as np
+from typing import Dict, List
 
 import openai
 import pytest
-import requests
-import torch
 
-from vllm import ModelRegistry
-from vllm.config import MultiModalConfig
-from vllm.inputs import INPUT_REGISTRY
-from vllm.inputs.data import LLMInputs
-from vllm.inputs.registry import InputContext
-from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.opt import OPTForCausalLM
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.base import MultiModalInputs
-from vllm.multimodal.image import (cached_get_tokenizer,
-                                   repeat_and_pad_image_tokens)
-from vllm.multimodal.utils import encode_audio_base64, fetch_audio
-from vllm.utils import get_open_port
 
-from ...utils import VLLM_PATH
+from vllm.multimodal.utils import encode_audio_base64, fetch_audio
 
-chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
-assert chatml_jinja_path.exists()
+from ...utils import RemoteOpenAIServer
 
-MODEL_NAME = "facebook/opt-125m"
+MODEL_NAME = "fixie-ai/ultravox-v0_3"
 TEST_AUDIO_URLS = [
     "https://upload.wikimedia.org/wikipedia/en/b/bf/Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg",
 ]
 
 
-def server_function(port):
-
-    def fake_input_mapper(ctx: InputContext, data: object):
-        assert isinstance(data, tuple)
-        (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data)
-
-        # Resample it to 1 sample per second
-        audio = librosa.resample(audio, orig_sr=sr, target_sr=1)
-        return MultiModalInputs({"processed_audio": torch.from_numpy(audio)})
-
-    def fake_input_processor(ctx: InputContext, llm_inputs: LLMInputs):
-        multi_modal_data = llm_inputs.get("multi_modal_data")
-        if multi_modal_data is None or "audio" not in multi_modal_data:
-            return llm_inputs
-
-        audio, sr = multi_modal_data.get("audio")
-        audio_duration = math.ceil(len(audio) / sr)
-
-        new_prompt, new_token_ids = repeat_and_pad_image_tokens(
-            cached_get_tokenizer(ctx.model_config.tokenizer),
-            llm_inputs.get("prompt"),
-            llm_inputs["prompt_token_ids"],
-            image_token_id=62,  # "_"
-            repeat_count=audio_duration)
-
-        return LLMInputs(prompt_token_ids=new_token_ids,
-                         prompt=new_prompt,
-                         multi_modal_data=multi_modal_data)
-
-    @MULTIMODAL_REGISTRY.register_input_mapper("audio", fake_input_mapper)
-    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-        "audio", lambda *_, **__: 100)
-    @INPUT_REGISTRY.register_input_processor(fake_input_processor)
-    class FakeAudioModel(OPTForCausalLM, SupportsMultiModal):
-
-        def __init__(self, *args, multimodal_config: MultiModalConfig,
-                     **kwargs):
-            assert multimodal_config is not None
-            super().__init__(*args, **kwargs)
-
-        def forward(
-            self,
-            *args,
-            processed_audio: Optional[torch.Tensor] = None,
-            **kwargs,
-        ) -> torch.Tensor:
-            return super().forward(*args, **kwargs)
-
-    ModelRegistry.register_model("OPTForCausalLM", FakeAudioModel)
-
-    with patch("vllm.entrypoints.chat_utils._mm_token_str",
-               lambda *_, **__: "_"):
-        sys.argv = ["placeholder.py"] + \
-            (f"--model {MODEL_NAME} --gpu-memory-utilization 0.10 "
-             "--dtype bfloat16 --enforce-eager --api-key token-abc123 "
-             f"--port {port} --chat-template {chatml_jinja_path} "
-             "--disable-frontend-multiprocessing").split()
-        import runpy
-        runpy.run_module('vllm.entrypoints.openai.api_server',
-                         run_name='__main__')
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--max-model-len",
+        "4096",
+        "--enforce-eager",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
 
 
 @pytest.fixture(scope="module")
-def client():
-    port = get_open_port()
-    ctx = torch.multiprocessing.get_context("spawn")
-    server = ctx.Process(target=server_function, args=(port, ))
-    server.start()
-    MAX_SERVER_START_WAIT_S = 60
-    client = openai.AsyncOpenAI(
-        base_url=f"http://localhost:{port}/v1",
-        api_key="token-abc123",
-    )
-    # run health check
-    health_url = f"http://localhost:{port}/health"
-    start = time.time()
-    while True:
-        try:
-            if requests.get(health_url).status_code == 200:
-                break
-        except Exception as err:
-            result = server.exitcode
-            if result is not None:
-                raise RuntimeError("Server exited unexpectedly.") from err
-
-            time.sleep(0.5)
-            if time.time() - start > MAX_SERVER_START_WAIT_S:
-                raise RuntimeError("Server failed to start in time.") from err
-
-    try:
-        yield client
-    finally:
-        server.kill()
+def client(server):
+    return server.get_async_client()
 
 
 @pytest.fixture(scope="session")
@@ -172,7 +73,7 @@ async def test_single_chat_session_audio(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+        completion_tokens=10, prompt_tokens=202, total_tokens=212)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -227,7 +128,7 @@ async def test_single_chat_session_audio_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=36, total_tokens=46)
+        completion_tokens=10, prompt_tokens=202, total_tokens=212)
 
     message = choice.message
     message = chat_completion.choices[0].message
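The rewritten test also keeps the encode_audio_base64 and fetch_audio imports, which points to a second path where the audio is inlined as base64 instead of referenced by URL. A minimal sketch of how those helpers could be combined, assuming fetch_audio(url) returns an (audio, sampling_rate) tuple and encode_audio_base64(audio, sampling_rate) returns a base64 string (both signatures are inferred, not shown in this diff):

    # Hedged sketch built around the imports above; the helper signatures are
    # assumptions inferred from their names, not confirmed by this diff.
    from vllm.multimodal.utils import encode_audio_base64, fetch_audio

    AUDIO_URL = ("https://upload.wikimedia.org/wikipedia/en/b/bf/"
                 "Dave_Niehaus_Winning_Call_1995_AL_Division_Series.ogg")

    audio, sampling_rate = fetch_audio(AUDIO_URL)             # assumed return shape
    audio_base64 = encode_audio_base64(audio, sampling_rate)  # assumed signature
    # The data URL would then stand in for the raw URL in the chat request.
    data_url = f"data:audio/ogg;base64,{audio_base64}"
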
6 changes: 4 additions & 2 deletions vllm/entrypoints/chat_utils.py
@@ -118,8 +118,8 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
                   modality: Literal["image", "audio"]) -> Optional[str]:
     # TODO: Let user specify how to insert image tokens into prompt
     # (similar to chat template)
+    model_type = model_config.hf_config.model_type
     if modality == "image":
-        model_type = model_config.hf_config.model_type
         if model_type == "phi3_v":
             # Workaround since this token is not defined in the tokenizer
             return "<|image_1|>"
@@ -135,7 +135,9 @@ def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
 
         raise TypeError(f"Unknown model type: {model_type}")
     elif modality == "audio":
-        raise TypeError("No audio models are supported yet.")
+        if model_type == "ultravox":
+            return "<|reserved_special_token_0|>"
+        raise TypeError(f"Unknown model type: {model_type}")
     else:
         raise TypeError(f"Unknown modality: {modality}")

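With model_type hoisted out of the image branch, the audio branch can now resolve a placeholder string for Ultravox instead of rejecting audio outright: the frontend splices <|reserved_special_token_0|> into the chat prompt once per audio input, and the model's input processor later expands it to the audio's actual token length. A standalone restatement of the new dispatch (illustrative only, not vLLM's internal helper):

    # Illustrative restatement of the lookup added above; only the
    # "ultravox" -> "<|reserved_special_token_0|>" mapping comes from the diff.
    from typing import Optional


    def audio_placeholder_token(model_type: str) -> Optional[str]:
        if model_type == "ultravox":
            return "<|reserved_special_token_0|>"
        raise TypeError(f"Unknown model type: {model_type}")


    placeholder = audio_placeholder_token("ultravox")
    prompt = f"{placeholder}\nWhat is happening in this clip?"
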
3 changes: 2 additions & 1 deletion vllm/model_executor/models/__init__.py
@@ -76,7 +76,8 @@
     "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
     "MedusaModel": ("medusa", "Medusa"),
     "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
-    "JambaForCausalLM": ("jamba", "JambaForCausalLM")
+    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
+    "UltravoxModel": ("ultravox", "UltravoxModel"),
 }
 
 _EMBEDDING_MODELS = {
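_MODELS maps a Hugging Face architecture name to a (module, class) pair under vllm/model_executor/models/, so this one-line entry is what lets an Ultravox checkpoint resolve to the new implementation. Models that live outside the tree can still be registered at runtime, which is exactly what the deleted test used to do; a minimal sketch of that pattern (the class here is a placeholder):

    # Hedged sketch of runtime registration, following the pattern used by the
    # fake model removed from test_audio.py; MyExperimentalModel is a placeholder.
    from vllm import ModelRegistry
    from vllm.model_executor.models.opt import OPTForCausalLM


    class MyExperimentalModel(OPTForCausalLM):
        pass


    # The key is the architecture string as it appears in the HF config.
    ModelRegistry.register_model("OPTForCausalLM", MyExperimentalModel)
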
8 changes: 4 additions & 4 deletions vllm/model_executor/models/blip.py
@@ -14,8 +14,8 @@
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.multimodal.image import (cached_get_tokenizer,
-                                   repeat_and_pad_image_tokens)
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
 
@@ -94,11 +94,11 @@ def input_processor_for_blip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=image_token_id,
+        placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 
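The helper formerly called repeat_and_pad_image_tokens now lives in vllm.multimodal.utils as repeat_and_pad_placeholder_tokens, since the same expansion serves audio placeholders as well as image ones. Conceptually it repeats the placeholder token repeat_count times, optionally bracketed by left/right pad tokens, in both the prompt string and the token-id list; a token-id-only illustration of that idea (not vLLM's implementation):

    # Illustrative sketch of "repeat and pad placeholder tokens" on token ids
    # only; the real helper also rewrites the prompt string.
    from typing import List, Optional


    def expand_placeholder(token_ids: List[int],
                           placeholder_token_id: int,
                           repeat_count: int,
                           pad_token_left: Optional[int] = None,
                           pad_token_right: Optional[int] = None) -> List[int]:
        replacement = [placeholder_token_id] * repeat_count
        if pad_token_left is not None:
            replacement = [pad_token_left] + replacement
        if pad_token_right is not None:
            replacement = replacement + [pad_token_right]

        expanded: List[int] = []
        for token_id in token_ids:
            expanded.extend(replacement if token_id == placeholder_token_id else [token_id])
        return expanded


    # One placeholder (id 32000) expanded to three feature positions:
    print(expand_placeholder([1, 32000, 7], 32000, repeat_count=3))
    # -> [1, 32000, 32000, 32000, 7]
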
8 changes: 4 additions & 4 deletions vllm/model_executor/models/chameleon.py
@@ -29,8 +29,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import (cached_get_tokenizer,
-                                   repeat_and_pad_image_tokens)
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 from vllm.utils import print_warning_once
 
@@ -120,11 +120,11 @@ def input_processor_for_chameleon(ctx: InputContext, llm_inputs: LLMInputs):
 
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer)
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=CHAMELEON_IMAGE_TOKEN_ID,
+        placeholder_token_id=CHAMELEON_IMAGE_TOKEN_ID,
         repeat_count=CHAMELEON_IMAGE_SEQ_LENGTH,
         pad_token_left=CHAMELEON_IMAGE_START_TOKEN_ID,
         pad_token_right=CHAMELEON_IMAGE_END_TOKEN_ID,
8 changes: 4 additions & 4 deletions vllm/model_executor/models/clip.py
@@ -15,8 +15,8 @@
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.image import (cached_get_tokenizer,
-                                   repeat_and_pad_image_tokens)
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
 
@@ -100,11 +100,11 @@ def input_processor_for_clip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=image_token_id,
+        placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 
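The same rename recurs in the Chameleon, CLIP, and SigLIP processors. These input_processor_for_* functions are one leg of how a multimodal model plugs into vLLM: an input mapper turns raw media into tensors, an input processor rewrites the prompt around the placeholder, and a max-tokens hook bounds the expansion. The deleted test registered a fake audio model through exactly these registries; a condensed sketch of that wiring (the decorator calls mirror the removed code, the bodies are trivial placeholders):

    # Hedged sketch of the registry wiring, condensed from the fake audio model
    # that this commit removes from test_audio.py; the mapper/processor bodies
    # are placeholders rather than a working audio pipeline.
    from vllm.inputs import INPUT_REGISTRY
    from vllm.inputs.data import LLMInputs
    from vllm.inputs.registry import InputContext
    from vllm.model_executor.models.interfaces import SupportsMultiModal
    from vllm.model_executor.models.opt import OPTForCausalLM
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.base import MultiModalInputs


    def my_input_mapper(ctx: InputContext, data: object) -> MultiModalInputs:
        # Turn the raw audio (or image) object into model input tensors.
        return MultiModalInputs({})


    def my_input_processor(ctx: InputContext, llm_inputs: LLMInputs) -> LLMInputs:
        # Expand the placeholder token to match the media's feature length.
        return llm_inputs


    @MULTIMODAL_REGISTRY.register_input_mapper("audio", my_input_mapper)
    @MULTIMODAL_REGISTRY.register_max_multimodal_tokens("audio", lambda *_, **__: 100)
    @INPUT_REGISTRY.register_input_processor(my_input_processor)
    class MyAudioModel(OPTForCausalLM, SupportsMultiModal):
        pass
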
4 changes: 2 additions & 2 deletions vllm/model_executor/models/fuyu.py
@@ -35,8 +35,8 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
-from vllm.multimodal.image import (cached_get_image_processor,
-                                   cached_get_tokenizer)
+from vllm.multimodal.image import cached_get_image_processor
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput, SequenceData
 
 from .interfaces import SupportsMultiModal
2 changes: 1 addition & 1 deletion vllm/model_executor/models/internvl.py
@@ -23,7 +23,7 @@
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
-from vllm.multimodal.image import cached_get_tokenizer
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
2 changes: 1 addition & 1 deletion vllm/model_executor/models/paligemma.py
@@ -16,7 +16,7 @@
 from vllm.model_executor.models.gemma import GemmaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import cached_get_tokenizer
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .interfaces import SupportsMultiModal
2 changes: 1 addition & 1 deletion vllm/model_executor/models/phi3v.py
@@ -37,7 +37,7 @@
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import cached_get_tokenizer
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
8 changes: 4 additions & 4 deletions vllm/model_executor/models/siglip.py
@@ -23,8 +23,8 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.multimodal.image import (cached_get_tokenizer,
-                                   repeat_and_pad_image_tokens)
+from vllm.multimodal.utils import (cached_get_tokenizer,
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData
 
 
@@ -109,11 +109,11 @@ def input_processor_for_siglip(
     else:
         image_feature_size = image_feature_size_override
 
-    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
         tokenizer,
         llm_inputs.get("prompt"),
         llm_inputs["prompt_token_ids"],
-        image_token_id=image_token_id,
+        placeholder_token_id=image_token_id,
         repeat_count=image_feature_size,
     )
 
