Skip to content

Commit 3f5996c

Browse files
committed
Fix lora modules and formatting
Remove stale comment. Add llama lora modules. Add llama test case. Add test case and log warning on missing lora modules. Roll back unwanted changes and format fixes. Signed-off-by: Sumit Vij <sumitvij11+github@gmail.com>
1 parent 5a6b79f commit 3f5996c

File tree

6 files changed

+128
-99
lines changed

6 files changed

+128
-99
lines changed

tests/conftest.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -733,14 +733,16 @@ def generate(
733733
images: Optional[PromptImageInput] = None,
734734
videos: Optional[PromptVideoInput] = None,
735735
audios: Optional[PromptAudioInput] = None,
736+
**kwargs: Any,
736737
) -> List[Tuple[List[List[int]], List[str]]]:
737738
inputs = self.get_inputs(prompts,
738739
images=images,
739740
videos=videos,
740741
audios=audios)
741742

742743
req_outputs = self.model.generate(inputs,
743-
sampling_params=sampling_params)
744+
sampling_params=sampling_params,
745+
**kwargs)
744746

745747
outputs: List[Tuple[List[List[int]], List[str]]] = []
746748
for req_output in req_outputs:
@@ -778,6 +780,7 @@ def generate_w_logprobs(
778780
images: Optional[PromptImageInput] = None,
779781
audios: Optional[PromptAudioInput] = None,
780782
videos: Optional[PromptVideoInput] = None,
783+
**kwargs: Any,
781784
) -> Union[List[TokensTextLogprobs],
782785
List[TokensTextLogprobsPromptLogprobs]]:
783786
inputs = self.get_inputs(prompts,
@@ -786,7 +789,8 @@ def generate_w_logprobs(
786789
audios=audios)
787790

788791
req_outputs = self.model.generate(inputs,
789-
sampling_params=sampling_params)
792+
sampling_params=sampling_params,
793+
**kwargs)
790794

791795
toks_str_logsprobs_prompt_logprobs = (
792796
self._final_steps_generate_w_logprobs(req_outputs))
@@ -822,13 +826,15 @@ def generate_greedy(
822826
images: Optional[PromptImageInput] = None,
823827
videos: Optional[PromptVideoInput] = None,
824828
audios: Optional[PromptAudioInput] = None,
829+
**kwargs: Any,
825830
) -> List[Tuple[List[int], str]]:
826831
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
827832
outputs = self.generate(prompts,
828833
greedy_params,
829834
images=images,
830835
videos=videos,
831-
audios=audios)
836+
audios=audios,
837+
**kwargs)
832838
return [(output_ids[0], output_str[0])
833839
for output_ids, output_str in outputs]
834840

@@ -843,6 +849,7 @@ def generate_greedy_logprobs(
843849
videos: Optional[PromptVideoInput] = None,
844850
stop_token_ids: Optional[List[int]] = None,
845851
stop: Optional[List[str]] = None,
852+
**kwargs: Any,
846853
) -> Union[List[TokensTextLogprobs],
847854
List[TokensTextLogprobsPromptLogprobs]]:
848855
greedy_logprobs_params = SamplingParams(
@@ -857,7 +864,8 @@ def generate_greedy_logprobs(
857864
greedy_logprobs_params,
858865
images=images,
859866
audios=audios,
860-
videos=videos)
867+
videos=videos,
868+
**kwargs)
861869

862870
def generate_encoder_decoder_greedy_logprobs(
863871
self,

tests/lora/conftest.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,18 +147,29 @@ def sql_lora_huggingface_id():
147147
# huggingface repo id is used to test lora runtime downloading.
148148
return "yard1/llama-2-7b-sql-lora-test"
149149

150+
150151
@pytest.fixture(scope="session")
151152
def sql_lora_files(sql_lora_huggingface_id):
152153
return snapshot_download(repo_id=sql_lora_huggingface_id)
153154

155+
154156
@pytest.fixture(scope="session")
155157
def llama3_1_8b_chess_lora():
156-
return snapshot_download(repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
158+
return snapshot_download(
159+
repo_id="mkopecki/chess-lora-adapter-llama-3.1-8b")
160+
161+
162+
@pytest.fixture(scope="session")
163+
def llama3_1_8b_ultravox_chess_lora():
164+
# the ultravox chess lora is the result of transforming the chess llama lora above
165+
return snapshot_download(repo_id="thedebugger11/ultravox-chess-lora")
166+
157167

158168
@pytest.fixture(scope="session")
159169
def lora_bias_files():
160170
return snapshot_download(repo_id="followumesh/granite-3b-lora8-bias")
161171

172+
162173
@pytest.fixture(scope="session")
163174
def mixtral_lora_files():
164175
# Note: this module has incorrect adapter_config.json to test
@@ -214,6 +225,7 @@ def baichuan_zero_lora_files():
214225
# all the lora_B weights are initialized to zero.
215226
return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
216227

228+
217229
@pytest.fixture(scope="session")
218230
def baichuan_regex_lora_files():
219231
return snapshot_download(repo_id="jeeejeee/baichuan-7b-lora-zero-regex")
@@ -223,6 +235,7 @@ def baichuan_regex_lora_files():
223235
def minicpmv_lora_files():
224236
return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon")
225237

238+
226239
@pytest.fixture(scope="session")
227240
def qwen2vl_lora_files():
228241
return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon")
@@ -232,6 +245,7 @@ def qwen2vl_lora_files():
232245
def tinyllama_lora_files():
233246
return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
234247

248+
235249
@pytest.fixture(scope="session")
236250
def phi2_lora_files():
237251
return snapshot_download(repo_id="isotr0py/phi-2-test-sql-lora")

tests/lora/test_ultravox.py

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,21 @@
1+
from typing import List, Tuple
12

2-
from typing import List
3+
from transformers import AutoTokenizer
34

4-
import pytest
5-
6-
import vllm
7-
8-
from transformers import AutoTokenizer
95
from vllm.lora.request import LoRARequest
10-
from vllm.platforms import current_platform
116

12-
MODEL_NAME = "fixie-ai/ultravox-v0_3"
7+
from ..models.utils import check_outputs_equal
8+
9+
ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
10+
LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
1311

1412
VLLM_PLACEHOLDER = "<|reserved_special_token_0|>"
1513

16-
EXPECTED_OUTPUT = [
17-
"Fool mate"
18-
]
14+
PROMPT = "Tell me about a silly chess move in 20 words"
15+
1916

20-
def _get_prompt(audio_count, question, placeholder):
21-
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17+
def _get_prompt(audio_count, question, placeholder, model_name) -> str:
18+
tokenizer = AutoTokenizer.from_pretrained(model_name)
2219
placeholder = f"{placeholder}\n" * audio_count
2320

2421
return tokenizer.apply_chat_template([{
@@ -28,44 +25,74 @@ def _get_prompt(audio_count, question, placeholder):
2825
tokenize=False,
2926
add_generation_prompt=True)
3027

31-
def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
32-
sampling_params = vllm.SamplingParams(
33-
temperature=0,
34-
max_tokens=1000,
35-
)
3628

37-
inputs = [{
38-
"prompt":_get_prompt(1, "Tell me about a silly chess move in 20 words", VLLM_PLACEHOLDER),
39-
}]
29+
def test_ultravox_lora(vllm_runner, llama3_1_8b_chess_lora,
30+
llama3_1_8b_ultravox_chess_lora):
31+
with vllm_runner(
32+
ULTRAVOX_MODEL_NAME,
33+
enforce_eager=True,
34+
max_num_seqs=128,
35+
enable_lora=True,
36+
max_loras=4,
37+
max_lora_rank=128,
38+
dtype="bfloat16",
39+
max_model_len=4096,
40+
) as vllm_model:
41+
ultravox_outputs: List[Tuple[List[int],
42+
str]] = vllm_model.generate_greedy(
43+
[
44+
_get_prompt(
45+
0, PROMPT, VLLM_PLACEHOLDER,
46+
ULTRAVOX_MODEL_NAME)
47+
],
48+
256,
49+
lora_request=LoRARequest(
50+
str(1), 1,
51+
llama3_1_8b_ultravox_chess_lora),
52+
)
53+
54+
# run llama with and without lora to compare outputs with above
55+
with vllm_runner(
56+
LLMA_MODEL_NAME,
57+
enforce_eager=True,
58+
max_num_seqs=128,
59+
enable_lora=True,
60+
max_loras=4,
61+
max_lora_rank=128,
62+
dtype="bfloat16",
63+
max_model_len=4096,
64+
) as vllm_model:
65+
llama_outputs_no_lora: List[Tuple[List[int],
66+
str]] = vllm_model.generate_greedy(
67+
[
68+
_get_prompt(
69+
0, PROMPT,
70+
VLLM_PLACEHOLDER,
71+
LLMA_MODEL_NAME)
72+
],
73+
256,
74+
)
75+
llama_outputs: List[Tuple[List[int],
76+
str]] = vllm_model.generate_greedy(
77+
[
78+
_get_prompt(0, PROMPT,
79+
VLLM_PLACEHOLDER,
80+
LLMA_MODEL_NAME)
81+
],
82+
256,
83+
lora_request=LoRARequest(
84+
str(1), 1, llama3_1_8b_chess_lora),
85+
)
4086

41-
outputs = llm.generate(
42-
inputs,
43-
sampling_params,
44-
lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
45-
if lora_id else None,
87+
check_outputs_equal(
88+
outputs_0_lst=ultravox_outputs,
89+
outputs_1_lst=llama_outputs,
90+
name_0="ultravox",
91+
name_1="llama",
4692
)
47-
generated_texts: List[str] = []
48-
for output in outputs:
49-
prompt = output.prompt
50-
generated_text = output.outputs[0].text.strip()
51-
generated_texts.append(generated_text)
52-
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
53-
return generated_texts
5493

94+
_, llama_no_lora_str = llama_outputs_no_lora[0]
95+
_, ultravox_str = ultravox_outputs[0]
5596

56-
def test_fixie_lora(llama3_1_8b_chess_lora):
57-
llm = vllm.LLM(
58-
MODEL_NAME,
59-
max_num_seqs=2,
60-
enable_lora=True,
61-
max_loras=4,
62-
max_lora_rank=128,
63-
trust_remote_code=True,
64-
dtype="bfloat16",
65-
max_model_len=4096,
66-
enforce_eager=True
67-
)
68-
output1 = do_sample(llm, llama3_1_8b_chess_lora, lora_id=1)
69-
for i in range(len(EXPECTED_OUTPUT)):
70-
assert EXPECTED_OUTPUT[i].startswith(output1[i])
71-
return None
97+
# verify that the text doesn't match the output generated without lora
98+
assert llama_no_lora_str != ultravox_str

vllm/assets/audio.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,6 @@
2020
class AudioAsset:
2121
name: Literal["winning_call", "mary_had_lamb"]
2222

23-
def __init__(self, audio_path=None):
24-
if audio_path is None:
25-
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
26-
s3_prefix=ASSET_DIR)
27-
28-
object.__setattr__(self, '_audio_path', audio_path)
29-
3023
@property
3124
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
3225
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",

vllm/lora/models.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,9 @@ def from_lora_tensors(
167167
loras[module_name].lora_b = loras[
168168
module_name].lora_b.pin_memory()
169169

170-
print_v=False
171170
for lora in loras.values():
172-
if "v_proj" in lora.module_name and not print_v:
173-
print_v=True
174-
logger.debug(f"Size of v_proj is: {lora.lora_a.size()}")
175171
lora.optimize()
176172

177-
logger.debug(f"Creating loras for {lora_model_id} with following modules {loras.keys()}")
178173
return cls(lora_model_id,
179174
peft_helper.r,
180175
loras,
@@ -392,11 +387,10 @@ def activate_adapter(
392387
logger.debug("Activating LoRA. int id: %d, slot index: %d",
393388
lora_model.id, index)
394389
self.lora_index_to_id[index] = lora_model.id
390+
missing_modules = []
395391
for module_name, module in self.modules.items():
396392
module_lora = lora_model.get_lora(module_name)
397393
if module_lora:
398-
logger.debug("Setting LoRA. int id: %d, module: %s",
399-
lora_model.id, module_name)
400394
module_lora.optimize()
401395
# Bias is not explicitly enabled with the flag enable_lora_bias.
402396
bias = module_lora.bias
@@ -412,9 +406,14 @@ def activate_adapter(
412406
module_lora.embeddings_tensor,
413407
module_lora.bias)
414408
else:
415-
logger.debug("Reseting lora. int id: %d, module: %s",
416-
lora_model.id, module_name)
409+
missing_modules.append(module_name)
417410
module.reset_lora(index)
411+
412+
if len(missing_modules) > 0:
413+
logger.warning(
414+
"Lora adapter int id %d is activated but is missing \
415+
base model modules %s which could impact output",
416+
lora_model.id, missing_modules)
418417
return True
419418

420419
def _deactivate_adapter(self, lora_id: int):
@@ -471,10 +470,6 @@ def _create_lora_modules(self):
471470
for module_name, module in self.model.named_modules(
472471
remove_duplicate=False):
473472

474-
logger.debug(
475-
"Create lora module if applicable %s",
476-
module_name,
477-
)
478473
if isinstance(module, PPMissingLayer):
479474
continue
480475
if not self._match_target_modules(module_name):
@@ -521,15 +516,12 @@ def _create_lora_modules(self):
521516
if self.supports_mm and not isinstance(new_module,
522517
BaseLayerWithLoRA):
523518
logger.warning(
524-
"%s module will be ignored because it isn't of type BaseLayerWithLoRA",
519+
"%s module will be ignored because it isn't of type \
520+
BaseLayerWithLoRA",
525521
module_name,
526522
)
527523
continue
528524

529-
logger.debug(
530-
"Going to apply lora on %s module",
531-
module_name,
532-
)
533525
self.register_module(module_name, new_module)
534526
self._register_packed_modules(module_name)
535527
# All lora layers share the same punica_wrapper based on reference.
@@ -545,9 +537,6 @@ def create_dummy_lora(
545537
rank: int,
546538
scaling_factor: Optional[float],
547539
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
548-
logger.debug(
549-
f"Creating a dummy lora with id: {lora_id}"
550-
)
551540
"""Create zero-initialized LoRAModel for warmup."""
552541
model = LoRAModel(lora_id, rank, {}, scaling_factor)
553542
for module_name, module in self.model.named_modules():

0 commit comments

Comments
 (0)