[Model] LoRA Support for Ultravox model #11253

Merged · 21 commits · Feb 6, 2025

Commits
1c55938
WIP: early draft of lora support in Ultravox
thedebugger Nov 18, 2024
5a6b79f
format fixes
thedebugger Nov 19, 2024
3f5996c
Fix lora modules and formatting
thedebugger Dec 17, 2024
d1b65eb
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 6, 2025
7367bc2
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 6, 2025
2abf2ab
Done
jeejeelee Jan 6, 2025
be87788
Address code review feedback
thedebugger Jan 10, 2025
317fc38
Merge branch 'main' into svij-ultravox-lora-dec-16
thedebugger Jan 11, 2025
4a633d3
Fix formatting and test case
thedebugger Jan 11, 2025
224a65e
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 16, 2025
769f7bd
Done
jeejeelee Jan 16, 2025
907b3c7
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 16, 2025
208e662
Add doc
jeejeelee Jan 16, 2025
1248d5f
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 18, 2025
575b5dc
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 19, 2025
7cb7eba
Merge branch 'main' of https://github.com/vllm-project/vllm into svij…
jeejeelee Jan 20, 2025
f483d9a
Optimize unit test
jeejeelee Jan 20, 2025
1976ee0
Test setting cpu as a default device
thedebugger Jan 22, 2025
80fb1b8
Merge remote-tracking branch 'origin/main' into svij-ultravox-lora-de…
thedebugger Jan 27, 2025
c5cdde7
Merge remote-tracking branch 'origin/main' into svij-ultravox-lora-de…
thedebugger Jan 28, 2025
0b18650
Merge remote-tracking branch 'origin/main' into svij-ultravox-lora-de…
thedebugger Jan 29, 2025
2 changes: 1 addition & 1 deletion docs/source/models/supported_models.md
@@ -767,7 +767,7 @@ See [this page](#generative-models) for more information on how to use generativ
* Ultravox
* T + A<sup>E+</sup>
* `fixie-ai/ultravox-v0_3`
*
* ✅︎
* ✅︎
* ✅︎
:::
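With this change, the supported-models table marks LoRA (✅︎) for Ultravox. As a rough illustration of what that enables, a minimal offline-inference sketch might look like the following; the adapter path is a placeholder and the plain-text prompt skips the chat template that the test below applies:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder adapter path; this PR does not ship an Ultravox LoRA artifact.
llm = LLM(model="fixie-ai/ultravox-v0_3", enable_lora=True, max_lora_rank=128)
outputs = llm.generate(
    "Tell me about a Fool's mate move in 20 words.",
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("chess-lora", 1, "/path/to/ultravox-chess-lora"),
)
print(outputs[0].outputs[0].text)
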
16 changes: 12 additions & 4 deletions tests/conftest.py
@@ -735,14 +735,16 @@ def generate(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
@@ -780,6 +782,7 @@ def generate_w_logprobs(
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
@@ -788,7 +791,8 @@ def generate_w_logprobs(
audios=audios)

req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
sampling_params=sampling_params,
**kwargs)

toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
@@ -824,13 +828,15 @@ def generate_greedy(
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
**kwargs: Any,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
audios=audios,
**kwargs)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]

@@ -845,6 +851,7 @@ def generate_greedy_logprobs(
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
@@ -859,7 +866,8 @@
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
videos=videos,
**kwargs)

def generate_encoder_decoder_greedy_logprobs(
self,
Review comment (Collaborator):
The changes to this file are not related to this PR, please revert.

121 changes: 121 additions & 0 deletions tests/lora/test_ultravox.py
@@ -0,0 +1,121 @@
import shutil
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"


def llama3_1_8b_chess_lora_path():
return snapshot_download(
repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")


# We can't reuse the Llama LoRA adapter without transforming its module names,
# because Ultravox nests the language model under a `language_model` prefix.
def transform_module_names_for_ultravox(state_dict):
transformed_state_dict = {}
for key, value in state_dict.items():
new_key = key.replace("base_model.model",
"base_model.model.language_model")
transformed_state_dict[new_key] = value
return transformed_state_dict


def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
tensor_file = "adapter_model.safetensors"
state_dict = load_file(path.join(source_repo, tensor_file))
transformed_state_dict = transform_module_names_for_ultravox(state_dict)

save_file(transformed_state_dict, path.join(target_path, tensor_file))

config_file = "adapter_config.json"
shutil.copyfile(path.join(source_repo, config_file),
path.join(target_path, config_file))
return target_path


def _get_prompt(audio_count, question, placeholder, model_name) -> str:
tokenizer = AutoTokenizer.from_pretrained(model_name)
placeholder = f"{placeholder}\n" * audio_count

return tokenizer.apply_chat_template([{
'role': 'user',
'content': f"{placeholder}{question}"
}],
tokenize=False,
add_generation_prompt=True)


def test_ultravox_lora(vllm_runner):
"""
TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
"""
# Workaround to prevent a device mismatch in Whisper.
# Can be removed once it is fixed upstream in transformers:
# https://github.com/huggingface/transformers/pull/35866
torch.set_default_device("cpu")
Review comment (Collaborator):
That's too hacky, don't do it. This significantly increased the CI testing time.

Review comment (Contributor Author @thedebugger, Jan 24, 2025):
I checked the time before and after: the delta was ~4 minutes. Now that I know this works, let me check whether I can find a better fix for it.

Source

Review comment (Contributor Author @thedebugger):
So it is likely a bug in transformers' Whisper. I've opened a PR to fix it.

The workaround isn't ideal, but the impact is limited IMO, given that the passing test takes roughly 2-3 minutes anyway. Moreover, the default device is only used when a device is not explicitly passed as a function parameter, which is why we aren't seeing much impact.

I also tried a few more options, like passing the device all the way through from vLLM -> Ultravox -> Whisper, but that is a lot more complicated and requires changes in a bunch of places. So I think it is okay to merge for now, and I can clean this up once it is fixed in transformers. Sounds okay?

Review comment (Contributor Author @thedebugger, Jan 27, 2025):
@jeejeelee, okay to merge if the tests are passing?

Btw, I verified that the device bug exists upstream, and I'm working on getting it patched.

Review comment (Contributor Author @thedebugger):
Hi @jeejeelee, is set_default_device a blocker? Let me know how you want to move forward.
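(A scoped alternative, sketched here for illustration only and not part of this PR: confine the default-device override to this one test with a fixture, so other tests in the same CI worker are unaffected. The fixture name is hypothetical.)

import pytest
import torch


@pytest.fixture
def cpu_default_device():
    # torch.device(...) works as a context manager (PyTorch >= 2.0): "cpu" is
    # the default device only while the test body runs, then it is restored.
    with torch.device("cpu"):
        yield

The test would then request cpu_default_device instead of calling torch.set_default_device inside its body.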


llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
with TemporaryDirectory() as temp_ultravox_lora_dir:
llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
with vllm_runner(
ULTRAVOX_MODEL_NAME,
enforce_eager=True,
max_num_seqs=2,
enable_lora=True,
max_loras=1,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=1024,
) as vllm_model:
ultravox_outputs: List[Tuple[
List[int], str]] = vllm_model.generate_greedy(
[
_get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
ULTRAVOX_MODEL_NAME)
],
256,
lora_request=LoRARequest(str(1), 1,
llama3_1_8b_ultravox_chess_lora),
)

# run llama with and without lora to compare outputs with above
with vllm_runner(
LLMA_MODEL_NAME,
enforce_eager=True,
max_num_seqs=2,
enable_lora=True,
max_loras=1,
max_lora_rank=128,
dtype="bfloat16",
max_model_len=1024,
) as vllm_model:
llama_outputs: List[Tuple[List[int], str]] = (
vllm_model.generate_greedy(
[_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
256,
lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
))

check_outputs_equal(
outputs_0_lst=ultravox_outputs,
outputs_1_lst=llama_outputs,
name_0="ultravox",
name_1="llama",
)
28 changes: 26 additions & 2 deletions vllm/model_executor/models/ultravox.py
@@ -20,6 +20,7 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader.loader import DefaultModelLoader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
@@ -31,7 +32,7 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.ultravox import UltravoxConfig

from .interfaces import SupportsMultiModal, SupportsPP
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings,
@@ -337,7 +338,20 @@ def forward(
UltravoxMultiModalProcessor,
info=UltravoxProcessingInfo,
dummy_inputs=UltravoxDummyInputsBuilder)
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP):
class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA):

packed_modules_mapping = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}

# LoRA specific attributes
# TODO: Add LoRA to the audio tower and projector.
supported_lora_modules = [
"qkv_proj", "o_proj", "gate_up_proj", "down_proj"
]
embedding_modules = {}
embedding_padding_modules = []

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
@@ -385,6 +399,16 @@ def sampler(self):

return get_sampler()

def get_mm_mapping(self) -> MultiModelKeys:
"""
Get the module prefix in multimodal models
"""
return MultiModelKeys.from_string_field(
language_model="language_model.",
connector="multi_modal_projector.",
tower_model="audio_tower.",
)

def _audio_features_to_embeddings(
self, input_features: torch.Tensor) -> torch.Tensor:
audio_input = input_features.to(self.audio_tower.dtype)
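For reference, the packed module names declared above line up with the per-projection Llama modules that a PEFT-style adapter would typically target; a hypothetical adapter configuration (a sketch, not taken from this PR) could look like:

# Hypothetical PEFT-style LoRA settings whose target modules map onto the
# packed vLLM modules above (q/k/v -> qkv_proj, gate/up -> gate_up_proj).
lora_config = {
    "r": 128,
    "lora_alpha": 256,
    "target_modules": [
        "q_proj", "k_proj", "v_proj",  # packed into qkv_proj
        "o_proj",
        "gate_proj", "up_proj",        # packed into gate_up_proj
        "down_proj",
    ],
}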