Commit 1976ee0

Test setting cpu as a default device

- Reduce model len and max num seq to reduce memory
- Re-trigger tests

Signed-off-by: Sumit Vij <sumitvij11+github@gmail.com>

1 parent f483d9a
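
The headline change pins CPU as the process-wide default device to work around a device mismatch in Whisper (see the diff below). For context, a minimal sketch of what `torch.set_default_device` does; this is standard PyTorch behavior, and the variable names are illustrative only:

```python
import torch

# Factory functions (zeros, ones, empty, tensor, ...) allocate on the
# process-wide default device unless an explicit device= is passed.
torch.set_default_device("cpu")

x = torch.zeros(4)          # lands on CPU even if CUDA is available
print(x.device)             # -> cpu

# An explicit device= still overrides the default.
if torch.cuda.is_available():
    y = torch.ones(4, device="cuda")
    print(y.device)         # -> cuda:0
```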

File tree

1 file changed: +22 -12 lines changed

tests/lora/test_ultravox.py

Lines changed: 22 additions & 12 deletions
```diff
@@ -3,12 +3,15 @@
 from tempfile import TemporaryDirectory
 from typing import List, Tuple
 
+import torch
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file, save_file
 from transformers import AutoTokenizer
 
 from vllm.lora.request import LoRARequest
 
+from ..models.utils import check_outputs_equal
+
 ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
 LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 
@@ -62,19 +65,24 @@ def test_ultravox_lora(vllm_runner):
     """
     TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
     """
+    # Workaround to prevent device mismatch in Whisper.
+    # Can be removed when it is fixed upstream in transformer
+    # https://github.com/huggingface/transformers/pull/35866
+    torch.set_default_device("cpu")
+
     llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
     with TemporaryDirectory() as temp_ultravox_lora_dir:
         llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
             llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
         with vllm_runner(
                 ULTRAVOX_MODEL_NAME,
                 enforce_eager=True,
-                max_num_seqs=128,
+                max_num_seqs=2,
                 enable_lora=True,
-                max_loras=4,
+                max_loras=1,
                 max_lora_rank=128,
                 dtype="bfloat16",
-                max_model_len=4096,
+                max_model_len=1024,
         ) as vllm_model:
             ultravox_outputs: List[Tuple[
                 List[int], str]] = vllm_model.generate_greedy(
@@ -91,21 +99,23 @@ def test_ultravox_lora(vllm_runner):
     with vllm_runner(
             LLMA_MODEL_NAME,
             enforce_eager=True,
-            max_num_seqs=128,
+            max_num_seqs=2,
             enable_lora=True,
-            max_loras=4,
+            max_loras=1,
             max_lora_rank=128,
             dtype="bfloat16",
-            max_model_len=4096,
+            max_model_len=1024,
     ) as vllm_model:
-        llama_outputs_no_lora: List[Tuple[List[int], str]] = (
+        llama_outputs: List[Tuple[List[int], str]] = (
             vllm_model.generate_greedy(
                 [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
                 256,
+                lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
             ))
 
-        _, llama_no_lora_str = llama_outputs_no_lora[0]
-        _, ultravox_str = ultravox_outputs[0]
-
-        # verify that text don't match with no lora
-        assert llama_no_lora_str != ultravox_str
+    check_outputs_equal(
+        outputs_0_lst=ultravox_outputs,
+        outputs_1_lst=llama_outputs,
+        name_0="ultravox",
+        name_1="llama",
+    )
```
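
The last hunk swaps the old "outputs must differ without LoRA" assertion for `check_outputs_equal` from `..models.utils`, which asserts that the Ultravox-with-LoRA and Llama-with-LoRA outputs match. A minimal sketch of the pairwise comparison such a helper performs on `(token_ids, text)` tuples; this is an illustrative approximation, not vLLM's actual implementation:

```python
from typing import List, Tuple


def check_outputs_equal(
    *,
    outputs_0_lst: List[Tuple[List[int], str]],
    outputs_1_lst: List[Tuple[List[int], str]],
    name_0: str,
    name_1: str,
) -> None:
    """Assert that two runs produced identical greedy outputs."""
    assert len(outputs_0_lst) == len(outputs_1_lst)

    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0 = outputs_0
        output_ids_1, output_str_1 = outputs_1

        # Compare decoded text first for a readable failure message,
        # then the raw token ids.
        assert output_str_0 == output_str_1, (
            f"Test {prompt_idx}: {name_0} text {output_str_0!r} "
            f"!= {name_1} text {output_str_1!r}")
        assert output_ids_0 == output_ids_1, (
            f"Test {prompt_idx}: {name_0} token ids != {name_1} token ids")
```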

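Separately, the reduced engine arguments cut memory: `max_num_seqs` caps how many sequences are batched concurrently and `max_model_len` caps the context length, so lowering both shrinks the KV cache, while `max_loras=1` keeps only a single LoRA adapter resident. Assuming the `vllm_runner` fixture forwards these kwargs to `vllm.LLM`, the equivalent standalone configuration would look roughly like:

```python
from vllm import LLM

# Memory-trimmed configuration mirroring the test's knobs. These are
# standard vLLM engine arguments; exact savings depend on the model.
llm = LLM(
    model="fixie-ai/ultravox-v0_3",
    enforce_eager=True,    # skip CUDA graph capture
    max_num_seqs=2,        # at most 2 sequences batched at once
    max_model_len=1024,    # shorter context -> smaller KV cache
    enable_lora=True,
    max_loras=1,           # one LoRA adapter resident at a time
    max_lora_rank=128,
    dtype="bfloat16",
)
```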