|
1 | 1 | # SPDX-License-Identifier: Apache-2.0
|
2 | 2 |
|
3 | 3 | from collections import OrderedDict
|
| 4 | +from typing import NamedTuple, Optional |
4 | 5 | from unittest.mock import patch
|
5 | 6 |
|
6 | 7 | import pytest
|
|
9 | 10 |
|
10 | 11 | from vllm.lora.utils import (get_adapter_absolute_path,
|
11 | 12 | parse_fine_tuned_lora_name, replace_submodule)
|
| 13 | +from vllm.model_executor.models.utils import WeightsMapper |
| 14 | + |
| 15 | + |
class LoRANameParserTestConfig(NamedTuple):
    """One fixture case for the parse_fine_tuned_lora_name tests."""
    # Raw checkpoint weight name fed to the parser.
    name: str
    # Expected module path the parser should return for `name`.
    module_name: str
    # Expected flag: True when the name refers to a lora_A /
    # lora_embedding_A tensor, False for the B-side tensor.
    is_lora_a: bool
    # Expected flag: True when the name refers to a bias tensor.
    is_bias: bool
    # Optional prefix remapping applied by the parser before matching
    # (e.g. "model." -> "language_model.model."); None disables remapping.
    weights_mapper: Optional[WeightsMapper] = None
12 | 22 |
|
13 | 23 |
|
def test_parse_fine_tuned_lora_name_valid():
    """parse_fine_tuned_lora_name must map well-formed LoRA weight names
    to the expected (module_name, is_lora_a, is_bias) triple, optionally
    applying a WeightsMapper prefix rewrite first."""
    cases = [
        LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight",
                                 "lm_head", True, False),
        LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight",
                                 "lm_head", False, False),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_A",
            "model.embed_tokens",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.embed_tokens.lora_embedding_B",
            "model.embed_tokens",
            False,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "model.layers.9.mlp.down_proj",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "model.layers.9.mlp.down_proj",
            False,
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.layers.9.mlp.down_proj",
            True,
            False,
        ),
        LoRANameParserTestConfig(
            "language_model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.layers.9.mlp.down_proj",
            False,
            False,
        ),
        # Cases that exercise the WeightsMapper prefix remapping path.
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
        LoRANameParserTestConfig(
            "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_A.weight",
            "language_model.model.layers.9.mlp.down_proj",
            True,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
        LoRANameParserTestConfig(
            "model.layers.9.mlp.down_proj.lora_B.weight",
            "language_model.model.layers.9.mlp.down_proj",
            False,
            False,
            weights_mapper=WeightsMapper(
                orig_to_new_prefix={"model.": "language_model.model."}),
        ),
    ]
    for case in cases:
        parsed = parse_fine_tuned_lora_name(case.name, case.weights_mapper)
        assert parsed == (case.module_name, case.is_lora_a, case.is_bias)
58 | 103 |
|
59 | 104 |
|
60 | 105 | def test_parse_fine_tuned_lora_name_invalid():
|
|
0 commit comments