
Commit d5c5154

[Misc] LoRA + Chunked Prefill (#9057)
1 parent 9a93973 commit d5c5154

12 files changed: 49 additions, 20 deletions
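
This commit turns on chunked prefill across the LoRA test suite and relaxes the config check that previously rejected the combination. As a minimal usage sketch (not taken from the diff itself): the base model name and adapter path below are placeholders, and the constructor arguments mirror the ones the updated tests pass to vllm.LLM.

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholders for illustration: substitute a real base model and a LoRA
# adapter trained for it.
llm = LLM(model="meta-llama/Llama-2-7b-hf",
          enable_lora=True,
          max_loras=4,
          max_lora_rank=64,
          enable_chunked_prefill=True)  # now usable together with enable_lora

outputs = llm.generate(
    ["Write a SQL query that counts users per country."],
    SamplingParams(temperature=0, max_tokens=64),
    # LoRARequest(adapter name, unique integer id, local adapter path)
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/lora_adapter"),
)
print(outputs[0].outputs[0].text)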

tests/lora/test_chatglm3_tp.py

Lines changed: 6 additions & 3 deletions

@@ -53,7 +53,8 @@ def test_chatglm3_lora(chatglm3_lora_files):
                   max_loras=4,
                   max_lora_rank=64,
                   tensor_parallel_size=1,
-                  trust_remote_code=True)
+                  trust_remote_code=True,
+                  enable_chunked_prefill=True)

    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -73,7 +74,8 @@ def test_chatglm3_lora_tp4(chatglm3_lora_files):
                   max_lora_rank=64,
                   tensor_parallel_size=4,
                   trust_remote_code=True,
-                  fully_sharded_loras=False)
+                  fully_sharded_loras=False,
+                  enable_chunked_prefill=True)

    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
@@ -93,7 +95,8 @@ def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files):
                   max_lora_rank=64,
                   tensor_parallel_size=4,
                   trust_remote_code=True,
-                  fully_sharded_loras=True)
+                  fully_sharded_loras=True,
+                  enable_chunked_prefill=True)
    output1 = do_sample(llm, chatglm3_lora_files, lora_id=1)
    for i in range(len(EXPECTED_LORA_OUTPUT)):
        assert output1[i] == EXPECTED_LORA_OUTPUT[i]

tests/lora/test_gemma.py

Lines changed: 2 additions & 1 deletion

@@ -37,7 +37,8 @@ def test_gemma_lora(gemma_lora_files):
    llm = vllm.LLM(MODEL_PATH,
                   max_model_len=1024,
                   enable_lora=True,
-                  max_loras=4)
+                  max_loras=4,
+                  enable_chunked_prefill=True)

    expected_lora_output = [
        "more important than knowledge.\nAuthor: Albert Einstein\n",

tests/lora/test_llama_tp.py

Lines changed: 5 additions & 1 deletion

@@ -78,7 +78,8 @@ def test_llama_lora(sql_lora_files):
                   enable_lora=True,
                   max_num_seqs=16,
                   max_loras=4,
-                  tensor_parallel_size=1)
+                  tensor_parallel_size=1,
+                  enable_chunked_prefill=True)
    generate_and_test(llm, sql_lora_files)


@@ -120,6 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
        max_num_seqs=16,
        max_loras=4,
        tensor_parallel_size=4,
+       enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)

@@ -135,6 +137,7 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
        max_loras=4,
        tensor_parallel_size=4,
        fully_sharded_loras=True,
+       enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)

@@ -151,5 +154,6 @@ def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
        tensor_parallel_size=4,
        fully_sharded_loras=True,
        enable_lora_bias=True,
+       enable_chunked_prefill=True,
    )
    generate_and_test(llm, sql_lora_files)

tests/lora/test_long_context.py

Lines changed: 2 additions & 1 deletion

@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
        tensor_parallel_size=4,
        # FIXME enable async output processor
        disable_async_output_proc=True,
-       distributed_executor_backend="mp")
+       distributed_executor_backend="mp",
+       enable_chunked_prefill=True)
    yield llm
    del llm

tests/lora/test_minicpmv.py

Lines changed: 2 additions & 1 deletion

@@ -67,7 +67,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
        max_loras=4,
        max_lora_rank=64,
        trust_remote_code=True,
-       gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
+       gpu_memory_utilization=0.97, # This model is pretty big for CI gpus
+       enable_chunked_prefill=True,
    )
    output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)
    for i in range(len(EXPECTED_OUTPUT)):

tests/lora/test_minicpmv_tp.py

Lines changed: 2 additions & 0 deletions

@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
        tensor_parallel_size=2,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
+       enable_chunked_prefill=True,
    )

    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
        tensor_parallel_size=4,
        trust_remote_code=True,
        fully_sharded_loras=fully_sharded,
+       enable_chunked_prefill=True,
    )
    output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
    for i in range(len(EXPECTED_OUTPUT)):

tests/lora/test_mixtral.py

Lines changed: 1 addition & 0 deletions

@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
        max_loras=4,
        distributed_executor_backend="ray",
        tensor_parallel_size=tp_size,
+       enable_chunked_prefill=True,
    )

    expected_lora_output = [

tests/lora/test_phi.py

Lines changed: 2 additions & 1 deletion

@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
                   max_model_len=1024,
                   enable_lora=True,
                   max_loras=2,
-                  enforce_eager=True)
+                  enforce_eager=True,
+                  enable_chunked_prefill=True)

    expected_lora_output = [
        "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501

tests/lora/test_quant_model.py

Lines changed: 6 additions & 3 deletions

@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
                   tensor_parallel_size=tp_size,
                   gpu_memory_utilization=0.2, #avoid OOM
                   quantization=model.quantization,
-                  trust_remote_code=True)
+                  trust_remote_code=True,
+                  enable_chunked_prefill=True)

    if model.quantization is None:
        expected_no_lora_output = [
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
                       tensor_parallel_size=1,
                       gpu_memory_utilization=0.2, #avoid OOM
                       quantization=model.quantization,
-                      trust_remote_code=True)
+                      trust_remote_code=True,
+                      enable_chunked_prefill=True)
    output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)

    del llm_tp1
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
                       max_loras=4,
                       tensor_parallel_size=2,
                       gpu_memory_utilization=0.2, #avoid OOM
-                      quantization=model.quantization)
+                      quantization=model.quantization,
+                      enable_chunked_prefill=True)
    output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)

    del llm_tp2

vllm/config.py

Lines changed: 2 additions & 1 deletion

@@ -1698,7 +1698,8 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        # Reminder: Please update docs/source/usage/compatibility_matrix.rst
        # If the feature combo become valid
        if scheduler_config.chunked_prefill_enabled:
-            raise ValueError("LoRA is not supported with chunked prefill yet.")
+            logger.warning("LoRA with chunked prefill is still experimental "
+                           "and may be unstable.")


@dataclass
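
Per the hunk above, combining enable_lora with enable_chunked_prefill no longer fails LoRAConfig validation; it only emits the warning added in this commit. A minimal sketch of the resulting behavior at engine construction, with a placeholder model name:

from vllm import LLM

# Before this commit, this constructor raised:
#   ValueError: LoRA is not supported with chunked prefill yet.
# After it, construction proceeds and vLLM logs:
#   LoRA with chunked prefill is still experimental and may be unstable.
llm = LLM(model="meta-llama/Llama-2-7b-hf",  # placeholder model
          enable_lora=True,
          enable_chunked_prefill=True)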
