[Model] LoRA Support for Ultravox model #11253
@@ -0,0 +1,121 @@
import shutil
from os import path
from tempfile import TemporaryDirectory
from typing import List, Tuple

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file
from transformers import AutoTokenizer

from vllm.lora.request import LoRARequest

from ..models.utils import check_outputs_equal

ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"

PROMPT = "Tell me about a Fool's mate move in 20 words. Provide the moves!"


def llama3_1_8b_chess_lora_path():
    return snapshot_download(
        repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")


# The Llama LoRA adapter can't be used as-is: its module names must be
# transformed because Ultravox nests the language model.
def transform_module_names_for_ultravox(state_dict):
    transformed_state_dict = {}
    for key, value in state_dict.items():
        new_key = key.replace("base_model.model",
                              "base_model.model.language_model")
        transformed_state_dict[new_key] = value
    return transformed_state_dict
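
# Illustrative example of the renaming above (the exact module path is only an
# assumption for illustration): a PEFT key such as
#   "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
# becomes
#   "base_model.model.language_model.model.layers.0.self_attn.q_proj.lora_A.weight"
# so that it resolves against the language model nested inside Ultravox.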


def mk_llama3_1_8b_ultravox_chess_lora(source_repo, target_path):
    tensor_file = "adapter_model.safetensors"
    state_dict = load_file(path.join(source_repo, tensor_file))
    transformed_state_dict = transform_module_names_for_ultravox(state_dict)

    save_file(transformed_state_dict, path.join(target_path, tensor_file))

    config_file = "adapter_config.json"
    shutil.copyfile(path.join(source_repo, config_file),
                    path.join(target_path, config_file))
    return target_path
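
# Note: after this helper runs, target_path contains adapter_model.safetensors
# with the renamed module keys plus an unmodified copy of adapter_config.json,
# i.e. a LoRA adapter directory that can be handed to a LoRARequest (as done
# in the test below).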


def _get_prompt(audio_count, question, placeholder, model_name) -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    placeholder = f"{placeholder}\n" * audio_count

    return tokenizer.apply_chat_template([{
        'role': 'user',
        'content': f"{placeholder}{question}"
    }],
                                         tokenize=False,
                                         add_generation_prompt=True)
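
# Usage sketch: the test below calls this with audio_count=0, so the prompt
# contains no audio placeholders and the same text-only question can be fed
# to both Ultravox and the base Llama model, e.g.
#   _get_prompt(0, PROMPT, VLLM_PLACEHOLDER, ULTRAVOX_MODEL_NAME)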


def test_ultravox_lora(vllm_runner):
    """
    TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
    """
    # Workaround to prevent a device mismatch in Whisper.
    # Can be removed once it is fixed upstream in transformers:
    # https://github.com/huggingface/transformers/pull/35866
    torch.set_default_device("cpu")

Review discussion on the torch.set_default_device("cpu") workaround:

- That's too hacky, don't do it. This significantly increased the CI testing time.
- I checked the time before and after: the delta was ~4 minutes. Knowing this works, let me check whether I can find a better fix for this.
- So it is likely a bug in transformers' Whisper. I've opened a PR to fix that. The workaround isn't ideal, but the impact is limited IMO, given that the passing test takes roughly 2-3 minutes anyway. Moreover, the default device is used only when a device is not explicitly passed as a function parameter, which is why we aren't seeing much impact. I also tried a few more options, like passing the device all the way through from vLLM -> Ultravox -> Whisper, but that is a lot more complicated and requires changes in a bunch of places. So I think it is okay to merge for now, and I can clean it up once it is fixed in transformers. Sounds okay?
- @jeejeelee okay to merge if tests are passing? Btw, I verified that the device bug exists upstream and I'm working on getting it patched.
- Hi @jeejeelee, is set_default_device a blocker? Let me know how you want to move forward.
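
A minimal sketch of one way the workaround could be narrowed (an assumption, not something this PR does): recent PyTorch versions let a torch.device object act as a context manager that sets the default device only for the enclosed block, which would confine the CPU default to the code that actually trips the Whisper mismatch.

    import torch

    # Hypothetical scoping of the workaround; build_ultravox_model is a
    # placeholder name, and this assumes the device mismatch only occurs
    # while the model is being constructed.
    with torch.device("cpu"):
        model = build_ultravox_model()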

    llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
    with TemporaryDirectory() as temp_ultravox_lora_dir:
        llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
            llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
        with vllm_runner(
                ULTRAVOX_MODEL_NAME,
                enforce_eager=True,
                max_num_seqs=2,
                enable_lora=True,
                max_loras=1,
                max_lora_rank=128,
                dtype="bfloat16",
                max_model_len=1024,
        ) as vllm_model:
            ultravox_outputs: List[Tuple[
                List[int], str]] = vllm_model.generate_greedy(
                    [
                        _get_prompt(0, PROMPT, VLLM_PLACEHOLDER,
                                    ULTRAVOX_MODEL_NAME)
                    ],
                    256,
                    lora_request=LoRARequest(str(1), 1,
                                             llama3_1_8b_ultravox_chess_lora),
                )

    # Run the base Llama model with the original chess LoRA so its output can
    # be compared with the Ultravox output above.
    with vllm_runner(
            LLMA_MODEL_NAME,
            enforce_eager=True,
            max_num_seqs=2,
            enable_lora=True,
            max_loras=1,
            max_lora_rank=128,
            dtype="bfloat16",
            max_model_len=1024,
    ) as vllm_model:
        llama_outputs: List[Tuple[List[int], str]] = (
            vllm_model.generate_greedy(
                [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
                256,
                lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
            ))

    check_outputs_equal(
        outputs_0_lst=ultravox_outputs,
        outputs_1_lst=llama_outputs,
        name_0="ultravox",
        name_1="llama",
    )

Review comment: The changes to this file are not related to this PR, please revert.