
Commit 2672284

jeejeelee authored and rasmith committed
[Misc][LoRA] Improve the readability of LoRA error messages (vllm-project#12102)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent c2a8acf commit 2672284

File tree

10 files changed: +245 -116 lines changed


tests/entrypoints/openai/test_lora_adapters.py

Lines changed: 50 additions & 19 deletions
@@ -17,6 +17,33 @@
 # generation quality here
 LORA_NAME = "typeof/zephyr-7b-beta-lora"
 
+BADREQUEST_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+
 
 @pytest.fixture(scope="module")
 def zephyr_lora_files():
@@ -138,32 +165,36 @@ async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
 
 
 @pytest.mark.asyncio
-async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
-                                              tmp_path, zephyr_lora_files):
-    invalid_rank = tmp_path / "invalid_rank"
-
-    # Copy adapter from zephyr_lora_files to invalid_rank
-    shutil.copytree(zephyr_lora_files, invalid_rank)
-
-    with open(invalid_rank / "adapter_config.json") as f:
+@pytest.mark.parametrize("test_name,config_change,expected_error",
+                         BADREQUEST_CASES)
+async def test_dynamic_lora_badrequests(client: openai.AsyncOpenAI, tmp_path,
+                                        zephyr_lora_files, test_name: str,
+                                        config_change: dict,
+                                        expected_error: str):
+    # Create test directory
+    test_dir = tmp_path / test_name
+
+    # Copy adapter files
+    shutil.copytree(zephyr_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
         adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)
 
-    print(adapter_config)
-
-    # assert False
-
-    # Change rank to invalid value
-    adapter_config["r"] = 1024
-    with open(invalid_rank / "adapter_config.json", "w") as f:
+    # Save modified configuration
+    with open(config_path, "w") as f:
         json.dump(adapter_config, f)
 
-    with pytest.raises(openai.BadRequestError,
-                       match="is greater than max_lora_rank"):
+    # Test loading the adapter
+    with pytest.raises(openai.BadRequestError, match=expected_error):
         await client.post("load_lora_adapter",
                           cast_to=str,
                           body={
-                              "lora_name": "invalid-json",
-                              "lora_path": str(invalid_rank)
+                              "lora_name": test_name,
+                              "lora_path": str(test_dir)
                           })
 
 

tests/lora/test_lora_checkpoints.py

Lines changed: 16 additions & 0 deletions
@@ -3,6 +3,7 @@
 import pytest
 
 from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 from vllm.model_executor.models.utils import WeightsMapper
 
@@ -30,11 +31,14 @@ def test_load_checkpoints(
         else:
             expected_lora_modules.append(module)
     if lora_name == "baichuan7B":
+        peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                                max_position_embeddings=4096)
         # For the baichuan7B model, load it's LoRA,
         # and the test should pass.
         LoRAModel.from_local_checkpoint(
             baichuan_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -43,19 +47,25 @@ def test_load_checkpoints(
         # Test that the target_modules contain prefix
         # such as "model.layers.0.self_atten.W_pack", and
         # the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_zero_lora_files,
+                                                max_position_embeddings=4096)
         LoRAModel.from_local_checkpoint(
             baichuan_zero_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
     elif lora_name == "baichuan7B-zero-regex":
         # Test that the `target_modules` in the form of regular expressions,
         # such as `model\\..*(W_pack|o_proj)`, and the test should pass.
+        peft_helper = PEFTHelper.from_local_dir(baichuan_regex_lora_files,
+                                                max_position_embeddings=4096)
         LoRAModel.from_local_checkpoint(
             baichuan_regex_lora_files,
             expected_lora_modules,
+            peft_helper=peft_helper,
             lora_model_id=1,
             device="cpu",
             embedding_modules=embedding_modules,
@@ -64,10 +74,13 @@ def test_load_checkpoints(
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
         expected_error = "Please verify that the loaded LoRA module is correct"  # noqa: E501
+        peft_helper = PEFTHelper.from_local_dir(chatglm3_lora_files,
+                                                max_position_embeddings=4096)
         with pytest.raises(ValueError, match=expected_error):
             LoRAModel.from_local_checkpoint(
                 chatglm3_lora_files,
                 expected_lora_modules,
+                peft_helper=peft_helper,
                 lora_model_id=1,
                 device="cpu",
                 embedding_modules=embedding_modules,
@@ -94,9 +107,12 @@ def test_lora_weights_mapping(baichuan_lora_files):
             ".layers.": ".baichuan_layers.",
         },
     )
+    peft_helper = PEFTHelper.from_local_dir(baichuan_lora_files,
+                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_local_checkpoint(
        baichuan_lora_files,
        expected_lora_modules,
+       peft_helper=peft_helper,
        lora_model_id=1,
        device="cpu",
        embedding_modules=embedding_modules,
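
Note: across these test files the change is the same two-step loading flow — the adapter's adapter_config.json is first parsed into a PEFTHelper, which is then passed to LoRAModel.from_local_checkpoint. A minimal sketch of that flow follows; the adapter path, module list, and embedding mappings are purely illustrative assumptions, only the call pattern mirrors the diff above.

from vllm.lora.models import LoRAModel
from vllm.lora.peft_helper import PEFTHelper

# Placeholder inputs for illustration only.
adapter_dir = "/path/to/lora_adapter"
expected_lora_modules = ["q_proj", "v_proj"]
embedding_modules = {}          # hypothetical: no extra embedding modules
embedding_padding_modules = []  # hypothetical: nothing to pad

# Step 1: parse adapter_config.json once, up front.
peft_helper = PEFTHelper.from_local_dir(adapter_dir,
                                        max_position_embeddings=4096)

# Step 2: hand the parsed config to the checkpoint loader.
lora_model = LoRAModel.from_local_checkpoint(
    adapter_dir,
    expected_lora_modules,
    peft_helper=peft_helper,
    lora_model_id=1,
    device="cpu",
    embedding_modules=embedding_modules,
    embedding_padding_modules=embedding_padding_modules)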

tests/lora/test_lora_huggingface.py

Lines changed: 3 additions & 0 deletions
@@ -3,6 +3,7 @@
 import pytest
 
 from vllm.lora.models import LoRAModel
+from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.utils import get_adapter_absolute_path
 from vllm.model_executor.models.llama import LlamaForCausalLM
 
@@ -27,9 +28,11 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request):
     lora_path = get_adapter_absolute_path(lora_name)
 
     # lora loading should work for either absolute path and hugggingface id.
+    peft_helper = PEFTHelper.from_local_dir(lora_path, 4096)
     lora_model = LoRAModel.from_local_checkpoint(
         lora_path,
         expected_lora_modules,
+        peft_helper=peft_helper,
         lora_model_id=1,
         device="cpu",
         embedding_modules=embedding_modules,

tests/lora/test_lora_manager.py

Lines changed: 2 additions & 57 deletions
@@ -1,5 +1,3 @@
-import json
-import math
 import os
 from typing import Dict, List
 
@@ -34,68 +32,15 @@
 ] if current_platform.is_cuda_alike() else ["cpu"])
 
 
-def test_peft_helper(sql_lora_files):
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-    peft_helper = PEFTHelper.from_dict(config)
-    assert peft_helper.r == 8
-    assert peft_helper.lora_alpha == 16
-    assert peft_helper.target_modules == [
-        "q_proj",
-        "v_proj",
-        "k_proj",
-        "o_proj",
-        "gate_proj",
-        "up_proj",
-        "down_proj",
-        "embed_tokens",
-        "lm_head",
-    ]
-    scaling = peft_helper.lora_alpha / peft_helper.r
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    # test RSLoRA
-    config = dict(r=8,
-                  lora_alpha=16,
-                  target_modules=["gate_proj"],
-                  use_rslora=True)
-    peft_helper = PEFTHelper.from_dict(config)
-
-    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
-    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
-
-    expected_error = "vLLM only supports modules_to_save being None."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(
-            r=8,
-            lora_alpha=16,
-            target_modules=["gate_proj"],
-            modules_to_save=["lm_head"],
-        )
-        PEFTHelper.from_dict(config)
-
-    expected_error = "vLLM does not yet support DoRA."
-    with pytest.raises(ValueError, match=expected_error):
-        config = dict(r=8,
-                      lora_alpha=16,
-                      target_modules=["gate_proj"],
-                      use_dora=True)
-        PEFTHelper.from_dict(config)
-
-
 @pytest.mark.parametrize("device", DEVICES)
 def test_from_lora_tensors(sql_lora_files, device):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
 
-    lora_config_path = os.path.join(sql_lora_files, "adapter_config.json")
-    with open(lora_config_path) as f:
-        config = json.load(f)
-
-    peft_helper = PEFTHelper.from_dict(config)
+    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
+                                            max_position_embeddings=4096)
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
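
Note: the test_peft_helper test removed here reappears, expanded, in the new tests/lora/test_peft_helper.py below; the remaining caller switches from building the helper out of a hand-loaded dict to reading the adapter directory directly. A rough before/after sketch (the adapter path is illustrative; both calls appear in this diff):

import json
import os

from vllm.lora.peft_helper import PEFTHelper

adapter_dir = "/path/to/lora_adapter"  # illustrative

# Before: parse adapter_config.json by hand, then build the helper.
with open(os.path.join(adapter_dir, "adapter_config.json")) as f:
    peft_helper = PEFTHelper.from_dict(json.load(f))

# After: the helper reads the directory itself and records
# max_position_embeddings for long-context scaling.
peft_helper = PEFTHelper.from_local_dir(adapter_dir,
                                        max_position_embeddings=4096)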

tests/lora/test_peft_helper.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
+import json
+import math
+import shutil
+
+import pytest
+
+from vllm.config import LoRAConfig
+from vllm.lora.peft_helper import PEFTHelper
+
+ERROR_CASES = [
+    (
+        "test_rank",
+        {
+            "r": 1024
+        },
+        "is greater than max_lora_rank",
+    ),
+    (
+        "test_bias",
+        {
+            "bias": "all"
+        },
+        "Adapter bias cannot be used without bias_enabled",
+    ),
+    ("test_dora", {
+        "use_dora": True
+    }, "does not yet support DoRA"),
+    (
+        "test_modules_to_save",
+        {
+            "modules_to_save": ["lm_head"]
+        },
+        "only supports modules_to_save being None",
+    ),
+]
+
+
+def test_peft_helper_pass(long_context_lora_files_16k_1, tmp_path):
+    peft_helper = PEFTHelper.from_local_dir(long_context_lora_files_16k_1,
+                                            max_position_embeddings=4096)
+    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
+    peft_helper.validate_legal(lora_config)
+    assert peft_helper.r == 8
+    assert peft_helper.lora_alpha == 16
+    assert peft_helper.target_modules == [
+        "q_proj",
+        "v_proj",
+        "k_proj",
+        "o_proj",
+        "gate_proj",
+        "up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    assert peft_helper.context_length == 16384
+    assert peft_helper.vllm_max_position_embeddings == 4096
+    assert peft_helper.vllm_long_context_scaling_factor == float(
+        math.ceil(peft_helper.context_length /
+                  peft_helper.vllm_max_position_embeddings))
+    # test RSLoRA
+    rslora_config = dict(use_rslora=True)
+    test_dir = tmp_path / "test_rslora"
+    shutil.copytree(long_context_lora_files_16k_1, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
+        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(rslora_config)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
+        json.dump(adapter_config, f)
+
+    peft_helper = PEFTHelper.from_local_dir(test_dir,
+                                            max_position_embeddings=4096)
+    peft_helper.validate_legal(lora_config)
+    scaling = peft_helper.lora_alpha / math.sqrt(peft_helper.r)
+    assert abs(peft_helper.vllm_lora_scaling_factor - scaling) < 1e-3
+
+
+@pytest.mark.parametrize("test_name,config_change,expected_error", ERROR_CASES)
+def test_peft_helper_error(
+    sql_lora_files,
+    tmp_path,
+    test_name: str,
+    config_change: dict,
+    expected_error: str,
+):
+    test_dir = tmp_path / test_name
+    shutil.copytree(sql_lora_files, test_dir)
+
+    # Load and modify configuration
+    config_path = test_dir / "adapter_config.json"
+    with open(config_path) as f:
+        adapter_config = json.load(f)
+    # Apply configuration changes
+    adapter_config.update(config_change)
+
+    # Save modified configuration
+    with open(config_path, "w") as f:
+        json.dump(adapter_config, f)
+    lora_config = LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2)
+    # Test loading the adapter
+    with pytest.raises(ValueError, match=expected_error):
+        PEFTHelper.from_local_dir(
+            test_dir, max_position_embeddings=4096).validate_legal(lora_config)
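
Note: the clearer error text is what these cases assert on. A minimal sketch of how one such error would surface when validating an oversized rank; the adapter directory and the assumption that its adapter_config.json sets "r": 1024 are illustrative, and only the matched substring of the message is known from this diff.

from vllm.config import LoRAConfig
from vllm.lora.peft_helper import PEFTHelper

adapter_dir = "/path/to/lora_adapter"  # illustrative; assume its config sets "r": 1024

peft_helper = PEFTHelper.from_local_dir(adapter_dir, max_position_embeddings=4096)
try:
    peft_helper.validate_legal(
        LoRAConfig(max_lora_rank=16, max_cpu_loras=3, max_loras=2))
except ValueError as e:
    # Expected to contain "is greater than max_lora_rank", per ERROR_CASES above.
    print(e)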

vllm/engine/multiprocessing/engine.py

Lines changed: 1 addition & 0 deletions
@@ -296,6 +296,7 @@ def _handle_load_adapter_request(self, request: RPCLoadAdapterRequest):
                                is_engine_errored=False,
                                exception=e)
             self._send_outputs(rpc_err)
+            return
         # Otherwise, send back the successful load message
         self._send_outputs(
             RPCAdapterLoadedResponse(request_id=request.request_id))
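
Note: this one-line fix matters for how errors reach the client — without the early return, the handler would send the RPCError for a failed load and then fall through to also send RPCAdapterLoadedResponse. A schematic of the intended control flow (a simplified sketch, not the full method body; the try/except structure around the visible lines is assumed):

# Sketch only: names mirror the handler above, surrounding details are elided.
def _handle_load_adapter_request(self, request):
    try:
        self.engine.add_lora(request.lora_request)
    except BaseException as e:
        rpc_err = RPCError(request_id=request.request_id,
                           is_engine_errored=False,
                           exception=e)
        self._send_outputs(rpc_err)
        return  # stop here so a failure is reported exactly once
    # Only reached on success.
    self._send_outputs(RPCAdapterLoadedResponse(request_id=request.request_id))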
