Skip to content

Commit d88506d

Browse files
authored
[Model] LoRA Support for Ultravox model (#11253)
1 parent 9cdea30 commit d88506d

File tree

4 files changed

+160
-7
lines changed

4 files changed

+160
-7
lines changed

docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -857,7 +857,7 @@ See [this page](#generative-models) for more information on how to use generativ
857857
* Ultravox
858858
* T + A<sup>E+</sup>
859859
* `fixie-ai/ultravox-v0_3`
860-
*
860+
* ✅︎
861861
* ✅︎
862862
* ✅︎
863863
:::

tests/conftest.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -737,14 +737,16 @@ def generate(
737737
images: Optional[PromptImageInput] = None,
738738
videos: Optional[PromptVideoInput] = None,
739739
audios: Optional[PromptAudioInput] = None,
740+
**kwargs: Any,
740741
) -> List[Tuple[List[List[int]], List[str]]]:
741742
inputs = self.get_inputs(prompts,
742743
images=images,
743744
videos=videos,
744745
audios=audios)
745746

746747
req_outputs = self.model.generate(inputs,
747-
sampling_params=sampling_params)
748+
sampling_params=sampling_params,
749+
**kwargs)
748750

749751
outputs: List[Tuple[List[List[int]], List[str]]] = []
750752
for req_output in req_outputs:
@@ -782,6 +784,7 @@ def generate_w_logprobs(
782784
images: Optional[PromptImageInput] = None,
783785
audios: Optional[PromptAudioInput] = None,
784786
videos: Optional[PromptVideoInput] = None,
787+
**kwargs: Any,
785788
) -> Union[List[TokensTextLogprobs],
786789
List[TokensTextLogprobsPromptLogprobs]]:
787790
inputs = self.get_inputs(prompts,
@@ -790,7 +793,8 @@ def generate_w_logprobs(
790793
audios=audios)
791794

792795
req_outputs = self.model.generate(inputs,
793-
sampling_params=sampling_params)
796+
sampling_params=sampling_params,
797+
**kwargs)
794798

795799
toks_str_logsprobs_prompt_logprobs = (
796800
self._final_steps_generate_w_logprobs(req_outputs))
@@ -826,13 +830,15 @@ def generate_greedy(
826830
images: Optional[PromptImageInput] = None,
827831
videos: Optional[PromptVideoInput] = None,
828832
audios: Optional[PromptAudioInput] = None,
833+
**kwargs: Any,
829834
) -> List[Tuple[List[int], str]]:
830835
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
831836
outputs = self.generate(prompts,
832837
greedy_params,
833838
images=images,
834839
videos=videos,
835-
audios=audios)
840+
audios=audios,
841+
**kwargs)
836842
return [(output_ids[0], output_str[0])
837843
for output_ids, output_str in outputs]
838844

@@ -847,6 +853,7 @@ def generate_greedy_logprobs(
847853
videos: Optional[PromptVideoInput] = None,
848854
stop_token_ids: Optional[List[int]] = None,
849855
stop: Optional[List[str]] = None,
856+
**kwargs: Any,
850857
) -> Union[List[TokensTextLogprobs],
851858
List[TokensTextLogprobsPromptLogprobs]]:
852859
greedy_logprobs_params = SamplingParams(
@@ -861,7 +868,8 @@ def generate_greedy_logprobs(
861868
greedy_logprobs_params,
862869
images=images,
863870
audios=audios,
864-
videos=videos)
871+
videos=videos,
872+
**kwargs)
865873

866874
def generate_encoder_decoder_greedy_logprobs(
867875
self,

tests/lora/test_ultravox.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import shutil
2+
from os import path
3+
from tempfile import TemporaryDirectory
4+
from typing import List, Tuple
5+
6+
import torch
7+
from huggingface_hub import snapshot_download
8+
from safetensors.torch import load_file, save_file
9+
from transformers import AutoTokenizer
10+
11+
from vllm.lora.request import LoRARequest
12+
13+
from ..models.utils import check_outputs_equal
14+
15+
ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
16+
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
17+
18+
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
19+
20+
PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"
21+
22+
23+
def llama3_1_8b_chess_lora_path():
24+
return snapshot_download(
25+
repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
26+
27+
28+
# can't use the llama lora adapter without module-name transformation
29+
# because ultravox nests the language model
30+
def transform_module_names_for_ultravox(state_dict):
31+
transformed_state_dict = {}
32+
for key, value in state_dict.items():
33+
new_key = key.replace("base_model.model",
34+
"base_model.model.language_model")
35+
transformed_state_dict[new_key] = value
36+
return transformed_state_dict
37+
38+
39+
def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
40+
tensor_file = "adapter_model.safetensors"
41+
state_dict = load_file(path.join(source_repo, tensor_file))
42+
transformed_state_dict = transform_module_names_for_ultravox(state_dict)
43+
44+
save_file(transformed_state_dict, path.join(target_path, tensor_file))
45+
46+
config_file = "adapter_config.json"
47+
shutil.copyfile(path.join(source_repo, config_file),
48+
path.join(target_path, config_file))
49+
return target_path
50+
51+
52+
def _get_prompt(audio_count, question, placeholder, model_name) -> str:
53+
tokenizer = AutoTokenizer.from_pretrained(model_name)
54+
placeholder = f"{placeholder}\n" * audio_count
55+
56+
return tokenizer.apply_chat_template([{
57+
'role': 'user',
58+
'content': f"{placeholder}{question}"
59+
}],
60+
tokenize=False,
61+
add_generation_prompt=True)
62+
63+
64+
def test_ultravox_lora(vllm_runner):
65+
"""
66+
TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
67+
"""
68+
# Workaround to prevent device mismatch in Whisper.
69+
# Can be removed when it is fixed upstream in transformers
70+
# https://github.com/huggingface/transformers/pull/35866
71+
torch.set_default_device("cpu")
72+
73+
llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
74+
with TemporaryDirectory() as temp_ultravox_lora_dir:
75+
llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
76+
llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
77+
with vllm_runner(
78+
ULTRAVOX_MODEL_NAME,
79+
enforce_eager=True,
80+
max_num_seqs=2,
81+
enable_lora=True,
82+
max_loras=1,
83+
max_lora_rank=128,
84+
dtype="bfloat16",
85+
max_model_len=1024,
86+
) as vllm_model:
87+
ultravox_outputs: List[Tuple[
88+
List[int], str]] = vllm_model.generate_greedy(
89+
[
90+
_get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
91+
ULTRAVOX_MODEL_NAME)
92+
],
93+
256,
94+
lora_request=LoRARequest(str(1), 1,
95+
llama3_1_8b_ultravox_chess_lora),
96+
)
97+
98+
# run llama with the original lora to compare its outputs with the above
99+
with vllm_runner(
100+
LLMA_MODEL_NAME,
101+
enforce_eager=True,
102+
max_num_seqs=2,
103+
enable_lora=True,
104+
max_loras=1,
105+
max_lora_rank=128,
106+
dtype="bfloat16",
107+
max_model_len=1024,
108+
) as vllm_model:
109+
llama_outputs: List[Tuple[List[int], str]] = (
110+
vllm_model.generate_greedy(
111+
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
112+
256,
113+
lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
114+
))
115+
116+
check_outputs_equal(
117+
outputs_0_lst=ultravox_outputs,
118+
outputs_1_lst=llama_outputs,
119+
name_0="ultravox",
120+
name_1="llama",
121+
)

vllm/model_executor/models/ultravox.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from vllm.model_executor.layers.layernorm import RMSNorm
2323
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
2424
from vllm.model_executor.model_loader.loader import DefaultModelLoader
25+
from vllm.model_executor.models.module_mapping import MultiModelKeys
2526
from vllm.model_executor.sampling_metadata import SamplingMetadata
2627
from vllm.multimodal import MULTIMODAL_REGISTRY
2728
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
@@ -33,7 +34,7 @@
3334
from vllm.sequence import IntermediateTensors
3435
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
3536

36-
from .interfaces import SupportsMultiModal, SupportsPP
37+
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
3738
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
3839
init_vllm_registered_model, maybe_prefix,
3940
merge_multimodal_embeddings,
@@ -343,7 +344,20 @@ def forward(
343344
UltravoxMultiModalProcessor,
344345
info=UltravoxProcessingInfo,
345346
dummy_inputs=UltravoxDummyInputsBuilder)
346-
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
347+
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):
348+
349+
packed_modules_mapping = {
350+
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
351+
"gate_up_proj": ["gate_proj", "up_proj"]
352+
}
353+
354+
# LoRA specific attributes
355+
# TODO: Add LoRA to the audio tower and projector.
356+
supported_lora_modules = [
357+
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
358+
]
359+
embedding_modules = {}
360+
embedding_padding_modules = []
347361

348362
hf_to_vllm_mapper = WeightsMapper(
349363
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
@@ -391,6 +405,16 @@ def sampler(self):
391405

392406
return get_sampler()
393407

408+
def get_mm_mapping(self) -> MultiModelKeys:
409+
"""
410+
Get the module prefix in multimodal models
411+
"""
412+
return MultiModelKeys.from_string_field(
413+
language_model="language_model.",
414+
connector="multi_modal_projector.",
415+
tower_model="audio_tower.",
416+
)
417+
394418
def _audio_features_to_embeddings(
395419
self, input_features: torch.Tensor) -> torch.Tensor:
396420
audio_input = input_features.to(self.audio_tower.dtype)

0 commit comments

Comments
 (0)