2 files changed: +20 -9 lines changed
tests/lora/test_utils.py

@@ -1,12 +1,13 @@
 from collections import OrderedDict
 
+import pytest
 
 from torch import nn
 
 from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
 from vllm.utils import LRUCache
 
 
-def test_parse_fine_tuned_lora_name():
+def test_parse_fine_tuned_lora_name_valid():
     fixture = {
         ("base_model.model.lm_head.lora_A.weight", "lm_head", True),
         ("base_model.model.lm_head.lora_B.weight", "lm_head", False),
@@ -35,6 +36,17 @@ def test_parse_fine_tuned_lora_name():
         assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name)
 
 
+def test_parse_fine_tuned_lora_name_invalid():
+    fixture = {
+        "weight",
+        "base_model.weight",
+        "base_model.model.weight",
+    }
+    for name in fixture:
+        with pytest.raises(ValueError, match="unsupported LoRA weight"):
+            parse_fine_tuned_lora_name(name)
+
+
 def test_replace_submodule():
     model = nn.Sequential(
         OrderedDict([
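For context on the new negative test: pytest.raises(..., match=...) applies a regular-expression search (re.search) to the string form of the raised exception, so the literal substring "unsupported LoRA weight" matches the full message f"{name} is unsupported LoRA weight". A minimal standalone sketch of that assertion (not part of the diff):

import pytest

# match= is treated as a regex and searched against str(exception),
# so a distinctive substring of the message is enough to pass.
with pytest.raises(ValueError, match="unsupported LoRA weight"):
    raise ValueError("base_model.weight is unsupported LoRA weight")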
vllm/lora/utils.py

@@ -94,13 +94,12 @@ def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
         is_lora_a whether the tensor is lora_a or lora_b.
     """
     parts = name.split(".")
-    assert parts[0] == "base_model"
-    assert parts[1] == "model"
-    if parts[-1] == "weight":
-        assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
-        return ".".join(parts[2:-2]), parts[-2] == "lora_A"
 
-    if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
-        return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
+    if len(parts) >= 2 and parts[0] == "base_model" and parts[1] == "model":
+        if parts[-1] == "weight":
+            if parts[-2] == "lora_A" or parts[-2] == "lora_B":
+                return ".".join(parts[2:-2]), parts[-2] == "lora_A"
+        elif parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
+            return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
 
-    raise ValueError(f"{name} is unsupported format")
+    raise ValueError(f"{name} is unsupported LoRA weight")
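Taken together, the rewrite replaces asserts (which raise the less precise AssertionError and are skipped entirely under python -O) with an explicit ValueError for any name that does not match the expected LoRA naming patterns. A minimal sketch of the resulting behavior, assuming the function is importable as in the test file's imports above; the inputs and outputs mirror the test fixtures in this diff:

from vllm.lora.utils import parse_fine_tuned_lora_name

# Valid names return (module_name, is_lora_a).
parse_fine_tuned_lora_name("base_model.model.lm_head.lora_A.weight")
# -> ("lm_head", True)
parse_fine_tuned_lora_name("base_model.model.lm_head.lora_B.weight")
# -> ("lm_head", False)

# Names that fit neither pattern now raise ValueError.
try:
    parse_fine_tuned_lora_name("base_model.model.weight")
except ValueError as err:
    print(err)  # base_model.model.weight is unsupported LoRA weight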