
[Bugfix] Fix LoRA loading check (vllm-project#4138)
Co-authored-by: simon-mo <simon.mo@hey.com>
2 people authored and robertgshaw2-neuralmagic committed Apr 26, 2024
1 parent e726a89 commit 643d8d1
Showing 3 changed files with 29 additions and 3 deletions.
6 changes: 6 additions & 0 deletions tests/lora/conftest.py
@@ -143,6 +143,12 @@ def baichuan_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
 
 
+@pytest.fixture(scope="session")
+def baichuan_zero_lora_files():
+    # all the lora_B weights are initialized to zero.
+    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
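For context, the new fixture relies on huggingface_hub's snapshot_download, which fetches the adapter repository (cached after the first call) and returns the local directory path that LoRAModel.from_local_checkpoint later reads. A minimal sketch of what the fixture resolves to; the printed path is environment-dependent:

from huggingface_hub import snapshot_download

# Download the zero-initialized Baichuan LoRA adapter and return its local directory,
# which holds adapter_config.json and the adapter weights.
lora_dir = snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
print(lora_dir)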
22 changes: 20 additions & 2 deletions tests/lora/test_lora_checkpoints.py
@@ -3,9 +3,16 @@
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
+lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
 
-@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
-def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+
+@pytest.mark.parametrize("lora_name", lora_lst)
+def test_load_checkpoints(
+    lora_name,
+    baichuan_lora_files,
+    baichuan_zero_lora_files,
+    chatglm3_lora_files,
+):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero":
+        # Test that target_modules may carry a full prefix,
+        # e.g. "model.layers.0.self_attn.W_pack"; the test
+        # should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_zero_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
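What the new "baichuan7B-zero" case exercises: the zero-init adapter lists its target modules with full layer prefixes rather than bare module names. The dict below is a hypothetical illustration of that shape (made-up values, not the real adapter_config.json):

# Hypothetical adapter_config.json shape for a prefix-style adapter (illustrative only).
prefixed_adapter_config = {
    "r": 8,
    "lora_alpha": 16,
    "target_modules": [
        "model.layers.0.self_attn.W_pack",  # fully qualified name
        "model.layers.0.mlp.down_proj",
    ],
}
# Before this fix, such names never matched the bare entries in
# expected_lora_modules (e.g. "W_pack"), so the checkpoint was rejected
# as containing unexpected modules.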
4 changes: 3 additions & 1 deletion vllm/lora/models.py
@@ -212,7 +212,9 @@ def from_local_checkpoint(
             target_modules = config["target_modules"]
             unexpected_modules = []
             for module in target_modules:
-                if module not in expected_lora_modules:
+                # Compatible with more module names, e.g. "layers.11.self_attn.k_proj".
+                part_name = module.split(".")[-1]
+                if part_name not in expected_lora_modules:
                     unexpected_modules.append(module)
             # loaded lora's target modules must be a subset of expected_lora_modules
             if unexpected_modules:
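A standalone sketch of the updated check, with assumed module lists: only the final dotted segment of each target module is compared against the expected names, so fully qualified entries now pass.

def find_unexpected_modules(target_modules, expected_lora_modules):
    # Mirror of the check above: compare only the last name component.
    unexpected = []
    for module in target_modules:
        part_name = module.split(".")[-1]
        if part_name not in expected_lora_modules:
            unexpected.append(module)
    return unexpected

# Assumed expected modules for illustration.
expected = ["W_pack", "o_proj", "k_proj"]
assert find_unexpected_modules(["layers.11.self_attn.k_proj"], expected) == []
assert find_unexpected_modules(["layers.11.self_attn.foo"], expected) == [
    "layers.11.self_attn.foo"
]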
