Commit 32309f7

Isotr0py authored and Mu Huai committed
[Bugfix] Fix missing lora name mapping for lora without prefix (vllm-project#17793)
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
1 parent 0865de7 commit 32309f7

2 files changed: +61 −14 lines
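Before the per-file diffs, a minimal sketch of the bug being fixed (the helper names below are illustrative, not vLLM's API): prior to this commit, `parse_fine_tuned_lora_name` only applied the `WeightsMapper` to LoRA weight names carrying the `base_model.model.` prefix, so prefix-less names fell through unmapped.

# Minimal sketch of the bug this commit fixes; `map_lora_name_*` and
# `remap` are hypothetical helpers standing in for the real code path.

def remap(n: str) -> str:
    # Mirrors the WeightsMapper config used in the new tests:
    # orig_to_new_prefix={"model.": "language_model.model."}
    if n.startswith("model."):
        return "language_model.model." + n.removeprefix("model.")
    return n

def map_lora_name_before_fix(name: str) -> str:
    if "base_model.model." in name:
        name = name.replace("base_model.model.", "")
        name = remap(name)
        name = "base_model.model." + name
    # BUG: a name without the prefix falls through unmapped.
    return name

def map_lora_name_after_fix(name: str) -> str:
    if name.startswith("base_model.model."):
        name = name.replace("base_model.model.", "")
        name = remap(name)
        name = "base_model.model." + name
    else:
        # The fix: prefix-less names are remapped too.
        name = remap(name)
    return name

print(map_lora_name_before_fix("model.layers.9.mlp.down_proj.lora_A.weight"))
# model.layers.9.mlp.down_proj.lora_A.weight  (unmapped: the bug)
print(map_lora_name_after_fix("model.layers.9.mlp.down_proj.lora_A.weight"))
# language_model.model.layers.9.mlp.down_proj.lora_A.weight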

tests/lora/test_utils.py

Lines changed: 57 additions & 12 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections import OrderedDict
+from typing import NamedTuple, Optional
 from unittest.mock import patch
 
 import pytest
@@ -9,52 +10,96 @@
 
 from vllm.lora.utils import (get_adapter_absolute_path,
                              parse_fine_tuned_lora_name, replace_submodule)
+from vllm.model_executor.models.utils import WeightsMapper
+
+
+class LoRANameParserTestConfig(NamedTuple):
+    name: str
+    module_name: str
+    is_lora_a: bool
+    is_bias: bool
+    weights_mapper: Optional[WeightsMapper] = None
 
 
 def test_parse_fine_tuned_lora_name_valid():
-    fixture = {
-        ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False),
-        ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False),
-        (
+    fixture = [
+        LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
+                                 "lm_head", True, False),
+        LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
+                                 "lm_head", False, False),
+        LoRANameParserTestConfig(
             "base_model.model.model.embed_tokens.lora_embedding_A",
             "model.embed_tokens",
             True,
             False,
         ),
-        (
+        LoRANameParserTestConfig(
             "base_model.model.model.embed_tokens.lora_embedding_B",
             "model.embed_tokens",
             False,
             False,
         ),
-        (
+        LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
             "model.layers.9.mlp.down_proj",
             True,
             False,
         ),
-        (
+        LoRANameParserTestConfig(
             "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
             "model.layers.9.mlp.down_proj",
             False,
             False,
         ),
-        (
+        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
            False,
        ),
-        (
+        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
            False,
        ),
-    }
-    for name, module_name, is_lora_a, is_bias in fixture:
+        # Test with WeightsMapper
+        LoRANameParserTestConfig(
+            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            True,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+        LoRANameParserTestConfig(
+            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            False,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+        LoRANameParserTestConfig(
+            "model.layers.9.mlp.down_proj.lora_A.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            True,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+        LoRANameParserTestConfig(
+            "model.layers.9.mlp.down_proj.lora_B.weight",
+            "language_model.model.layers.9.mlp.down_proj",
+            False,
+            False,
+            weights_mapper=WeightsMapper(
+                orig_to_new_prefix={"model.": "language_model.model."}),
+        ),
+    ]
+    for name, module_name, is_lora_a, is_bias, weights_mapper in fixture:
         assert (module_name, is_lora_a,
-                is_bias) == parse_fine_tuned_lora_name(name)
+                is_bias) == parse_fine_tuned_lora_name(name, weights_mapper)
 
 
 def test_parse_fine_tuned_lora_name_invalid():
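The new fixture entries rely on `WeightsMapper`'s `orig_to_new_prefix` remapping. A rough standalone sketch of just that behavior, assuming only the prefix-substitution part of the real class (the actual `vllm.model_executor.models.utils.WeightsMapper` also supports substring and suffix maps):

# Rough sketch of the prefix remapping the new fixtures exercise;
# `map_prefix` is a hypothetical stand-in, not the vLLM implementation.
from typing import Mapping

def map_prefix(name: str, orig_to_new_prefix: Mapping[str, str]) -> str:
    # Replace the first matching original prefix with its new prefix.
    for orig, new in orig_to_new_prefix.items():
        if name.startswith(orig):
            return new + name[len(orig):]
    return name

# Mirrors the fixture: "model." is remapped to "language_model.model."
assert map_prefix(
    "model.layers.9.mlp.down_proj.lora_A.weight",
    {"model.": "language_model.model."},
) == "language_model.model.layers.9.mlp.down_proj.lora_A.weight"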

vllm/lora/utils.py

Lines changed: 4 additions & 2 deletions
@@ -117,16 +117,18 @@ def parse_fine_tuned_lora_name(
     # LoRA weight qualified name usually starts with `base_model.model.`,
     # so we remove the prefix `base_model.model.` to make the following
     # mapping correctly.
-    if "base_model.model." in name:
+    if name.startswith("base_model.model."):
         name = name.replace("base_model.model.", "")
         name = weights_mapper._map_name(name) if weights_mapper else name
         # recover the prefix `base_model.model.`
         name = "base_model.model." + name
+    else:
+        name = weights_mapper._map_name(name) if weights_mapper else name
 
     # In some situations, we may not start with `base_model.model.`.
     # If we don't (e.g., ibm-granite/granite-speech-3.3-8b),
     # we should keep the prefix intact.
-    start_index = 2 if "base_model.model." in name else 0
+    start_index = 2 if name.startswith("base_model.model.") else 0
 
     parts = name.split(".")
     if parts[-1] == "weight" and (parts[-2] == "lora_A"
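A condensed sketch of the fixed control flow, trimmed to the `lora_A.weight` case (the real `parse_fine_tuned_lora_name` also handles `lora_B`, `lora_embedding_*`, and bias weights; `parse_lora_a` and `remap` below are illustrative names):

# Condensed, hypothetical sketch of the fixed flow; simplified to the
# `lora_A.weight` case only.
from typing import Callable, Optional, Tuple

def parse_lora_a(name: str,
                 map_name: Optional[Callable[[str], str]] = None
                 ) -> Optional[Tuple[str, bool]]:
    if name.startswith("base_model.model."):
        # Strip the prefix so the mapper sees the bare module path,
        # then restore the prefix afterwards.
        name = name.replace("base_model.model.", "")
        name = map_name(name) if map_name else name
        name = "base_model.model." + name
    else:
        # The fix: prefix-less names are mapped too.
        name = map_name(name) if map_name else name

    # Skip the restored "base_model.model." parts when present;
    # `startswith` (rather than `in`) avoids triggering this for
    # names that merely contain the substring mid-name.
    start_index = 2 if name.startswith("base_model.model.") else 0
    parts = name.split(".")
    if parts[-1] == "weight" and parts[-2] == "lora_A":
        return ".".join(parts[start_index:-2]), True
    return None

def remap(n: str) -> str:
    # Same mapping the tests configure via WeightsMapper.
    if n.startswith("model."):
        return "language_model.model." + n.removeprefix("model.")
    return n

print(parse_lora_a("model.layers.9.mlp.down_proj.lora_A.weight", remap))
# ('language_model.model.layers.9.mlp.down_proj', True)

Run against the fixture names above, both the prefixed and prefix-less inputs now resolve to the same mapped module name, which is exactly what the new test cases assert.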
