
Commit 0bb8c84

alex-jw-brooks authored and wuisawesome committed
[Model] Add Granite Speech Support (vllm-project#16246)
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
1 parent 3c8c822 commit 0bb8c84

File tree

11 files changed: +1025 -28 lines


docs/source/models/supported_models.md

Lines changed: 7 additions & 0 deletions
@@ -895,6 +895,13 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
   * ✅︎
   * ✅︎
+- * `GraniteSpeechForConditionalGeneration`
+  * Granite Speech
+  * T + A
+  * `ibm-granite/granite-speech-3.3-8b`
+  * ✅︎
+  * ✅︎
+  * ✅︎
 - * `H2OVLChatModel`
   * H2OVL
   * T + I<sup>E+</sup>

examples/offline_inference/audio_language.py

Lines changed: 32 additions & 0 deletions
@@ -38,6 +38,37 @@ class ModelRequestData(NamedTuple):
 # Unless specified, these settings have been tested to work on a single L4.
 
 
+# Granite Speech
+def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
+    # NOTE - the settings in this example are somewhat different from what is
+    # optimal for granite speech, and it is generally recommended to use beam
+    # search. Check the model README for suggested settings.
+    # https://huggingface.co/ibm-granite/granite-speech-3.3-8b
+    model_name = "ibm-granite/granite-speech-3.3-8b"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=2048,
+        max_num_seqs=2,
+        enable_lora=True,
+        max_lora_rank=64,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    # The model has an audio-specific LoRA directly in its model dir;
+    # it should be enabled whenever you pass audio inputs to the model.
+    speech_lora_path = model_name
+    audio_placeholder = "<|audio|>" * audio_count
+    prompts = f"<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|>{audio_placeholder}{question}<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompts,
+        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
+    )
+
+
 # MiniCPM-O
 def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
     model_name = "openbmb/MiniCPM-o-2_6"
@@ -209,6 +240,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
 
 
 model_example_map = {
+    "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,
     "qwen2_audio": run_qwen2_audio,

tests/conftest.py

Lines changed: 30 additions & 4 deletions
@@ -21,6 +21,7 @@
 from tests.models.utils import (TokensTextLogprobs,
                                 TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
+from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
 from vllm.config import TaskOption, _get_and_verify_dtype
@@ -103,10 +104,25 @@ def prompts(self, prompts: _VideoAssetPrompts) -> list[str]:
         return [prompts["sample_demo_1"]]
 
 
+class _AudioAssetsBase(UserList[AudioAsset]):
+    pass
+
+
+class _AudioAssets(_AudioAssetsBase):
+
+    def __init__(self) -> None:
+        super().__init__([
+            AudioAsset("mary_had_lamb"),
+            AudioAsset("winning_call"),
+        ])
+
+
 IMAGE_ASSETS = _ImageAssets()
 """Singleton instance of :class:`_ImageAssets`."""
 VIDEO_ASSETS = _VideoAssets()
 """Singleton instance of :class:`_VideoAssets`."""
+AUDIO_ASSETS = _AudioAssets()
+"""Singleton instance of :class:`_AudioAssets`."""
 
 
 @pytest.fixture(scope="function", autouse=True)
@@ -263,6 +279,11 @@ def video_assets() -> _VideoAssets:
     return VIDEO_ASSETS
 
 
+@pytest.fixture(scope="session")
+def audio_assets() -> _AudioAssets:
+    return AUDIO_ASSETS
+
+
 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
 _R = TypeVar("_R")
 
@@ -390,10 +411,15 @@ def get_inputs(
                 processor_kwargs["images"] = image
             if videos is not None and (video := videos[i]) is not None:
                 processor_kwargs["videos"] = video
-            if audios is not None and (audio_tuple := audios[i]) is not None:
-                audio, sr = audio_tuple
-                processor_kwargs["audio"] = audio
-                processor_kwargs["sampling_rate"] = sr
+            if audios is not None and (audio_inputs := audios[i]) is not None:
+                # HACK - not all processors take sampling_rate; we should
+                # clean this up in the future.
+                if len(audio_inputs) == 2:
+                    audio, sr = audio_inputs
+                    processor_kwargs["audio"] = audio
+                    processor_kwargs["sampling_rate"] = sr
+                else:
+                    processor_kwargs["audio"] = audio_inputs
 
             inputs = self.processor(**processor_kwargs)
             if isinstance(inputs, BatchFeature):
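
As a quick, hedged illustration (not in the diff), a downstream test can consume the new session-scoped audio_assets fixture as in the sketch below; the test name is hypothetical:

def test_audio_assets_smoke(audio_assets: _AudioAssets) -> None:
    # The singleton wraps the "mary_had_lamb" and "winning_call" assets.
    assert len(audio_assets) == 2
    waveform, sample_rate = audio_assets[0].audio_and_sample_rate
    assert sample_rate > 0
    assert len(waveform) > 0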
Lines changed: 143 additions & 0 deletions
@@ -0,0 +1,143 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from collections.abc import Sequence
+from typing import Optional
+
+import pytest
+from transformers import AutoModelForSpeechSeq2Seq
+
+from vllm.lora.request import LoRARequest
+from vllm.sequence import SampleLogprobs
+
+from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
+from ...registry import HF_EXAMPLE_MODELS
+from ...utils import check_logprobs_close
+
+HF_AUDIO_PROMPT = "<|start_of_role|>system<|end_of_role|>Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant<|end_of_text|>\n<|start_of_role|>user<|end_of_role|><|audio|>can you transcribe the speech into a written format?<|end_of_text|>\n<|start_of_role|>assistant<|end_of_role|>"  # noqa: E501
+
+
+def vllm_to_hf_output(
+    vllm_output: tuple[list[int], str, Optional[SampleLogprobs]],
+) -> tuple[list[int], str, Optional[SampleLogprobs]]:
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "<|end_of_text|>"
+
+    return output_ids, hf_output_str, out_logprobs
+
+
+MODEL_NAME = "ibm-granite/granite-speech-3.3-8b"
+# Audio LoRA co-exists directly in the model directory, but
+# currently still needs to be passed directly to vLLM.
+audio_lora_path = MODEL_NAME
+models = [MODEL_NAME]
+
+
+def run_test(
+    hf_runner: type[HfRunner],
+    vllm_runner: type[VllmRunner],
+    inputs: Sequence[tuple[list[str], PromptAudioInput]],
+    model: str,
+    *,
+    max_model_len: int,
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the audio fixtures for the test are from AUDIO_ASSETS.
+    For the huggingface runner, we provide the audio as input.
+    For the vllm runner, we provide MultiModalDataDict objects
+    and the corresponding MultiModalConfig as input.
+    Note, the text input is also adjusted to abide by the vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    # NOTE: take care of the order. Run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # If we run HF first, the cuda initialization will be done and it
+    # will hurt the multiprocessing backend with fork method (the default).
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(
+            model,
+            task="generate",
+            max_model_len=max_model_len,
+            max_num_seqs=1,
+            dtype=dtype,
+            limit_mm_per_prompt={"audio": 1},
+            tensor_parallel_size=tensor_parallel_size,
+            distributed_executor_backend=distributed_executor_backend,
+            enable_lora=True,
+            max_lora_rank=64,
+            enforce_eager=True,
+    ) as vllm_model:
+        lora_request = LoRARequest("audio", 1, audio_lora_path)
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, audios in inputs
+        ]
+
+    with hf_runner(model, dtype=dtype,
+                   auto_cls=AutoModelForSpeechSeq2Seq) as hf_model:
+
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    audios=[audios],
+                                                    eos_token_id=eos_token_id)
+            for prompts, audios in inputs
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+                                        vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=[
+                vllm_to_hf_output(output) for output in vllm_outputs
+            ],
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_model_len", [2048])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets,
+                dtype: str, max_model_len: int, max_tokens: int,
+                num_logprobs: int) -> None:
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+    model_info.check_transformers_version(on_fail="skip")
+
+    audio, sr = audio_assets[0].audio_and_sample_rate
+    # This model expects a 16k sample rate, which our test audio
+    # already has; if this changes, it may break this test,
+    # so we check it directly
+    assert sr == 16000
+    run_test(
+        hf_runner,
+        vllm_runner,
+        [
+            ([HF_AUDIO_PROMPT], [audio]),
+        ],
+        model,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
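
To make the comparison concrete, a small hedged illustration (not in the diff) of what vllm_to_hf_output compensates for: the vLLM runner's decoded string omits the trailing end-of-text token that the HF-side output retains, so it is appended back before check_logprobs_close compares the two. The token ids and text below are made up:

sample_vllm_output = ([312, 7, 9], "the speech says mary had a little lamb", None)
ids, text, logprobs = vllm_to_hf_output(sample_vllm_output)
assert text == "the speech says mary had a little lamb<|end_of_text|>"
assert ids == [312, 7, 9] and logprobs is None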

tests/models/decoder_only/audio_language/test_ultravox.py

Lines changed: 6 additions & 11 deletions
@@ -11,7 +11,7 @@
 from vllm.multimodal.audio import resample_audio_librosa
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import HfRunner, VllmRunner
+from ....conftest import HfRunner, VllmRunner, _AudioAssets
 from ....utils import RemoteOpenAIServer
 from ...registry import HF_EXAMPLE_MODELS
 from ...utils import check_logprobs_close
@@ -31,12 +31,6 @@
 }
 
 
-@pytest.fixture(scope="session")
-def audio_assets():
-    from vllm.assets.audio import AudioAsset
-    return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")]
-
-
 @pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
 def audio(request):
     from vllm.assets.audio import AudioAsset
@@ -59,7 +53,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
     pytest.param({}, marks=pytest.mark.cpu_model),
     pytest.param(CHUNKED_PREFILL_KWARGS),
 ])
-def server(request, audio_assets):
+def server(request, audio_assets: _AudioAssets):
     args = [
         "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
         "--limit-mm-per-prompt",
@@ -230,8 +224,9 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
     pytest.param({}, marks=pytest.mark.cpu_model),
     pytest.param(CHUNKED_PREFILL_KWARGS),
 ])
-def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
-                                     max_tokens: int, num_logprobs: int,
+def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
+                                     dtype: str, max_tokens: int,
+                                     num_logprobs: int,
                                      vllm_kwargs: dict) -> None:
 
     vllm_prompt = _get_prompt(len(audio_assets),
@@ -250,7 +245,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
 
 
 @pytest.mark.asyncio
-async def test_online_serving(client, audio_assets):
+async def test_online_serving(client, audio_assets: _AudioAssets):
     """Exercises online serving with/without chunked prefill enabled."""
 
     messages = [{

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -254,6 +254,7 @@ def _test_processing_correctness_mistral(
     "adept/fuyu-8b",
     "google/gemma-3-4b-it",
     "THUDM/glm-4v-9b",
+    "ibm-granite/granite-speech-3.3-8b",
     "h2oai/h2ovl-mississippi-800m",
     "OpenGVLab/InternVL2-1B",
     "HuggingFaceM4/Idefics3-8B-Llama3",

tests/models/registry.py

Lines changed: 3 additions & 1 deletion
@@ -298,9 +298,11 @@ def check_available_online(
                                           extras={"fork": "Isotr0py/deepseek-vl2-tiny"},  # noqa: E501
                                           max_transformers_version="4.48",  # noqa: E501
                                           transformers_version_reason="HF model is not compatible.",  # noqa: E501
-                                          hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}),  # noqa: E501
+                                          hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}),  # noqa: E501
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
+    "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-8b",  # noqa: E501
+                                                             min_transformers_version="4.52.0"),  # noqa: E501
     "GLM4VForCausalLM": _HfExamplesInfo("THUDM/glm-4v-9b",
                                         trust_remote_code=True,
                                         hf_overrides={"architectures": ["GLM4VForCausalLM"]}),  # noqa: E501

vllm/entrypoints/chat_utils.py

Lines changed: 1 addition & 1 deletion
@@ -517,7 +517,7 @@ def _placeholder_str(self, modality: ModalityStr,
 
             raise TypeError(f"Unknown {modality} model type: {model_type}")
         elif modality == "audio":
-            if model_type == "ultravox":
+            if model_type in ("ultravox", "granite_speech"):
                 return "<|audio|>"
             if model_type == "phi4mm":
                 return f"<|audio_{current_count}|>"
