Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 1b41d11

Browse files
DarkLight1337 and Robert Shaw
authored and committed
[Misc] Improve error message when LoRA parsing fails (vllm-project#5194)
1 parent b21be06 commit 1b41d11

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

tests/lora/test_utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from collections import OrderedDict
22

3+
import pytest
34
from torch import nn
45

56
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
67
from vllm.utils import LRUCache
78

89

9-
def test_parse_fine_tuned_lora_name():
10+
def test_parse_fine_tuned_lora_name_valid():
1011
fixture = {
1112
("base_model.model.lm_head.lora_A.weight", "lm_head", True),
1213
("base_model.model.lm_head.lora_B.weight", "lm_head", False),
@@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name():
3536
assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
3637

3738

39+
def test_parse_fine_tuned_lora_name_invalid():
40+
fixture = {
41+
"weight",
42+
"base_model.weight",
43+
"base_model.model.weight",
44+
}
45+
for name in fixture:
46+
with pytest.raises(ValueError, match="unsupported LoRA weight"):
47+
parse_fine_tuned_lora_name(name)
48+
49+
3850
def test_replace_submodule():
3951
model = nn.Sequential(
4052
OrderedDict([

vllm/lora/utils.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
9494
is_lora_a whether the tensor is lora_a or lora_b.
9595
"""
9696
parts = name.split(".")
97-
assert parts[0] == "base_model"
98-
assert parts[1] == "model"
99-
if parts[-1] == "weight":
100-
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
101-
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
10297

103-
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
104-
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
98+
if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
99+
if parts[-1] == "weight":
100+
if parts[-2] == "lora_A" or parts[-2] == "lora_B":
101+
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
102+
elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
103+
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
105104

106-
raise ValueError(f"{name} is unsupported format")
105+
raise ValueError(f"{name} is unsupported LoRA weight")

0 commit comments

Comments (0)