[Lora] Use safetensor keys instead of adapter_config.json to find unexpected modules. (vllm-project#5909)

Co-authored-by: sang <sangcho@anyscale.com>
2 people authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent 3ceed36 commit 51f3e3f
Showing 4 changed files with 52 additions and 20 deletions.
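
In essence, the change reads module names out of the adapter's safetensors keys rather than trusting target_modules in adapter_config.json. A minimal sketch of that idea (not the exact vLLM code; the file path and the expected-module set are hypothetical):

```python
# Sketch only: derive an adapter's trained modules from its safetensors keys.
from safetensors import safe_open

def modules_from_safetensors(path: str) -> set:
    modules = set()
    with safe_open(path, framework="pt") as f:
        for key in f.keys():
            # PEFT-style keys look roughly like
            # "base_model.model.model.layers.11.self_attn.k_proj.lora_A.weight";
            # drop the ".lora_A/.lora_B..." suffix and keep the last component.
            module_name = key.rsplit(".lora_", 1)[0]
            modules.add(module_name.split(".")[-1])
    return modules

expected = {"q_proj", "k_proj", "v_proj", "o_proj"}  # hypothetical
unexpected = modules_from_safetensors("adapter_model.safetensors") - expected
if unexpected:
    raise ValueError(f"Unexpected LoRA modules: {unexpected}")
```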
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -226,3 +226,4 @@ steps:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
- pytest -v -s -x lora/test_mixtral.py
4 changes: 3 additions & 1 deletion tests/lora/conftest.py
@@ -165,7 +165,9 @@ def sql_lora_files():

@pytest.fixture(scope="session")
def mixtral_lora_files():
return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
# Note: this adapter has an intentionally incorrect adapter_config.json
# to exercise https://github.com/vllm-project/vllm/pull/5909/files.
return snapshot_download(repo_id="SangBinCho/mixtral-lora")


@pytest.fixture(scope="session")
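
As a usage note (a hypothetical sketch, not part of this commit): a fixture like mixtral_lora_files is typically consumed by pointing a LoRARequest at the downloaded directory, roughly:

```python
# Hypothetical usage sketch; the model name, prompt, and IDs are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

def test_mixtral_lora_smoke(mixtral_lora_files):
    llm = LLM(model="mistralai/Mixtral-8x7B-v0.1", enable_lora=True)
    outputs = llm.generate(
        ["Describe SpellForce 3."],
        SamplingParams(temperature=0, max_tokens=32),
        lora_request=LoRARequest("mixtral-lora", 1, mixtral_lora_files),
    )
    print(outputs[0].outputs[0].text)
```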
4 changes: 2 additions & 2 deletions tests/lora/test_mixtral.py
@@ -47,14 +47,14 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
enable_lora=True,
max_num_seqs=16,
max_loras=4,
distributed_executor_backend="ray",
tensor_parallel_size=tp_size)

expected_lora_output = [
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"give_opinion(name[SpellForce 3], release_year[2017], developer[Grimlore Games], rating[poor])", # noqa: E501
"give_opinion(name[SpellForce 3], developer[Grimlore Games], release_year[2017], rating[poor])", # noqa: E501
"inform(name[BioShock], release_year[2007], rating[good], genres[action-adventure, role-playing, shooter], platforms[PlayStation, Xbox, PC], available_on_steam[yes], has_linux_release[no], has_mac_release[yes])", # noqa: E501
]

assert do_sample(llm, mixtral_lora_files,
lora_id=1) == expected_lora_output
assert do_sample(llm, mixtral_lora_files,
63 changes: 46 additions & 17 deletions vllm/lora/models.py
@@ -303,25 +303,54 @@ def from_local_checkpoint(
"new_embeddings.bin")
with open(lora_config_path) as f:
config = json.load(f)
target_modules = config["target_modules"]
unexpected_modules = []
for module in target_modules:
# Compatible with more modules, such as: layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# loaded lora's target modules must be a subset of expected_lora_modules

if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
if os.path.isfile(lora_tensor_path):
tensors = safetensors.torch.load_file(lora_tensor_path)
tensors: Dict[str, torch.Tensor] = {}
# Find unexpected modules.
# Use the safetensors keys as the source of truth for which modules
# were actually trained: in PEFT, if target_modules lists A, B, and C
# but C does not exist in the base model, training does not error and
# the model is trained with only A and B LoRA-ified. C then has no
# entry in the safetensors file, even though it still appears in
# target_modules of adapter_config.json.
unexpected_modules = []
with safetensors.safe_open(lora_tensor_path,
framework="pt") as f: # type: ignore
for lora_module in f.keys(): # noqa
module_name, _ = parse_fine_tuned_lora_name(lora_module)
part_name = module_name.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module_name)
if unexpected_modules:
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct"
)
# Load tensors if there are only expected modules.
for module in f.keys(): # noqa
tensors[module] = f.get_tensor(module)
elif os.path.isfile(lora_bin_file_path):
# When a bin file is provided, we rely on config to find unexpected
# modules.
unexpected_modules = []
target_modules = config["target_modules"]
for module in target_modules:
# Compatible with more modules,
# such as: layers.11.self_attn.k_proj
part_name = module.split(".")[-1]
if part_name not in expected_lora_modules:
unexpected_modules.append(module)
# The loaded LoRA's target modules must be a subset of
# expected_lora_modules. This config-based check is not reliable
# (see https://github.com/vllm-project/vllm/pull/5909), but there is
# no better mechanism for .bin checkpoints.
if unexpected_modules:
print(unexpected_modules, "modules")
raise ValueError(
f"While loading {lora_dir}, expected"
f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path)
else:
raise ValueError(f"{lora_dir} doesn't contain tensors")