 from tempfile import TemporaryDirectory
 from typing import List, Tuple
 
+import torch
 from huggingface_hub import snapshot_download
 from safetensors.torch import load_file, save_file
 from transformers import AutoTokenizer
 
 from vllm.lora.request import LoRARequest
 
+from ..models.utils import check_outputs_equal
+
 ULTRAVOX_MODEL_NAME = "fixie-ai/ultravox-v0_3"
 LLMA_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 
@@ -62,19 +65,24 @@ def test_ultravox_lora(vllm_runner):
     """
     TODO: Train an Ultravox LoRA instead of using a Llama LoRA.
     """
+    # Workaround to prevent device mismatch in Whisper.
+    # Can be removed when it is fixed upstream in transformers:
+    # https://github.com/huggingface/transformers/pull/35866
+    torch.set_default_device("cpu")
+
     llama3_1_8b_chess_lora = llama3_1_8b_chess_lora_path()
     with TemporaryDirectory() as temp_ultravox_lora_dir:
         llama3_1_8b_ultravox_chess_lora = mk_llama3_1_8b_ultravox_chess_lora(
             llama3_1_8b_chess_lora, temp_ultravox_lora_dir)
         with vllm_runner(
                 ULTRAVOX_MODEL_NAME,
                 enforce_eager=True,
-                max_num_seqs=128,
+                max_num_seqs=2,
                 enable_lora=True,
-                max_loras=4,
+                max_loras=1,
                 max_lora_rank=128,
                 dtype="bfloat16",
-                max_model_len=4096,
+                max_model_len=1024,
         ) as vllm_model:
             ultravox_outputs: List[Tuple[
                 List[int], str]] = vllm_model.generate_greedy(
@@ -91,21 +99,23 @@ def test_ultravox_lora(vllm_runner):
     with vllm_runner(
             LLMA_MODEL_NAME,
             enforce_eager=True,
-            max_num_seqs=128,
+            max_num_seqs=2,
             enable_lora=True,
-            max_loras=4,
+            max_loras=1,
             max_lora_rank=128,
             dtype="bfloat16",
-            max_model_len=4096,
+            max_model_len=1024,
     ) as vllm_model:
-        llama_outputs_no_lora: List[Tuple[List[int], str]] = (
+        llama_outputs: List[Tuple[List[int], str]] = (
             vllm_model.generate_greedy(
                 [_get_prompt(0, PROMPT, VLLM_PLACEHOLDER, LLMA_MODEL_NAME)],
                 256,
+                lora_request=LoRARequest(str(1), 1, llama3_1_8b_chess_lora),
             ))
 
-    _, llama_no_lora_str = llama_outputs_no_lora[0]
-    _, ultravox_str = ultravox_outputs[0]
-
-    # verify that text don't match with no lora
-    assert llama_no_lora_str != ultravox_str
+    check_outputs_equal(
+        outputs_0_lst=ultravox_outputs,
+        outputs_1_lst=llama_outputs,
+        name_0="ultravox",
+        name_1="llama",
+    )
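
Note: mk_llama3_1_8b_ultravox_chess_lora is defined elsewhere in this file and is not shown in these hunks. As a rough, hypothetical sketch of what such a helper might do, assuming Ultravox exposes its Llama backbone under a "language_model" prefix and the adapter ships as adapter_config.json plus adapter_model.safetensors (the real helper may differ):

import json
import os

from huggingface_hub import snapshot_download
from safetensors.torch import load_file, save_file


def mk_ultravox_lora_sketch(llama_lora_id: str, out_dir: str) -> str:
    """Hypothetical: re-target a Llama LoRA at Ultravox's wrapped Llama."""
    src_dir = snapshot_download(repo_id=llama_lora_id)
    tensors = load_file(os.path.join(src_dir, "adapter_model.safetensors"))
    # Re-prefix every LoRA weight so it resolves against the language-model
    # submodule instead of the top-level Llama module tree.
    remapped = {
        k.replace("base_model.model.", "base_model.model.language_model."): v
        for k, v in tensors.items()
    }
    save_file(remapped, os.path.join(out_dir, "adapter_model.safetensors"))
    # Carry the adapter config over unchanged.
    with open(os.path.join(src_dir, "adapter_config.json")) as f:
        config = json.load(f)
    with open(os.path.join(out_dir, "adapter_config.json"), "w") as f:
        json.dump(config, f)
    return out_dir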
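Likewise, _get_prompt is a helper defined elsewhere in the file. A minimal sketch of what a prompt builder like it could look like, assuming it splices the audio placeholder into a single user turn and renders it with the model's chat template (names and behavior here are illustrative, not the actual helper):

from transformers import AutoTokenizer


def get_prompt_sketch(audio_count: int, question: str, placeholder: str,
                      model_name: str) -> str:
    # Hypothetical: one placeholder token per audio clip, then the question.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    user_content = "\n".join([placeholder] * audio_count + [question])
    return tokenizer.apply_chat_template(
        [{"role": "user", "content": user_content}],
        tokenize=False,
        add_generation_prompt=True,
    )

With audio_count=0, as in the Llama comparison above, this reduces to a plain text prompt, which is what lets the same question be sent to both the Ultravox and Llama runners for the equality check.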