
[Misc] Delete unused LoRA modules #13151

Merged · 3 commits · Feb 12, 2025
18 changes: 12 additions & 6 deletions tests/lora/test_lora_manager.py
@@ -606,27 +606,33 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):

     assert isinstance(model.get_submodule("gate_up_proj"),
                       MergedColumnParallelLinearWithLoRA)
     # Verify packed lora is correct
+    model_lora_clone = model_lora.clone(1)
+    model_lora_clone1 = model_lora1.clone(1)
     assert manager.add_adapter(model_lora)
     assert manager.add_adapter(model_lora1)
+
+    assert model_lora.get_lora("gate_proj") is None
+    assert model_lora.get_lora("up_proj") is None
+    assert model_lora1.get_lora("up_proj") is None
     packed_lora = model_lora.get_lora("gate_up_proj")
     assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

     torch.testing.assert_close(packed_lora.lora_a[0],
-                               model_lora.get_lora("gate_proj").lora_a)
+                               model_lora_clone.get_lora("gate_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[0],
-                               model_lora.get_lora("gate_proj").lora_b)
+                               model_lora_clone.get_lora("gate_proj").lora_b)
     torch.testing.assert_close(packed_lora.lora_a[1],
-                               model_lora.get_lora("up_proj").lora_a)
+                               model_lora_clone.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora.lora_b[1],
-                               model_lora.get_lora("up_proj").lora_b)
+                               model_lora_clone.get_lora("up_proj").lora_b)

     packed_lora1 = model_lora1.get_lora("gate_up_proj")
     assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

     assert packed_lora1.lora_a[0] is None
     assert packed_lora1.lora_b[0] is None
     torch.testing.assert_close(packed_lora1.lora_a[1],
-                               model_lora1.get_lora("up_proj").lora_a)
+                               model_lora_clone1.get_lora("up_proj").lora_a)
     torch.testing.assert_close(packed_lora1.lora_b[1],
-                               model_lora1.get_lora("up_proj").lora_b)
+                               model_lora_clone1.get_lora("up_proj").lora_b)
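
The clones are needed because, with this change, manager.add_adapter packs gate_proj and up_proj into gate_up_proj and then removes the sub-module entries from model_lora in place, so the test can no longer read the original weights back from model_lora for comparison. A minimal sketch of that clone-before-pack pattern, using a plain dict as a stand-in for LoRAModel (the names below are illustrative, not the vLLM API):

    import copy

    # Stand-ins for per-module LoRALayerWeights.
    loras = {"gate_proj": "A", "up_proj": "B"}
    reference = copy.deepcopy(loras)  # clone kept purely for later comparison

    # Packing merges the sub-modules and drops them in place, mirroring what
    # _create_merged_loras_inplace does after this PR.
    loras["gate_up_proj"] = (loras.pop("gate_proj"), loras.pop("up_proj"))

    assert loras.get("gate_proj") is None                       # sub-module entries are gone
    assert loras["gate_up_proj"][0] == reference["gate_proj"]   # compare via the clone
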
8 changes: 7 additions & 1 deletion vllm/lora/models.py
@@ -5,7 +5,8 @@
 import os
 import re
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
+from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
+                    Union)

 import safetensors.torch
 import torch
@@ -619,12 +620,14 @@ def _register_packed_modules(self, module_full_name: str) -> None:
     def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
         for module_name, new_module_names in self.packed_modules.items():
             replacement_loras: List[Optional[LoRALayerWeights]] = []
+            replaced_module: Set[str] = set()
             has_replacement = False
             for r in new_module_names:
                 lora = lora_model.get_lora(r)
                 replacement_loras.append(lora)
                 if lora:
                     has_replacement = True
+                    replaced_module.add(r)
             if not has_replacement:
                 continue
             for i in range(len(replacement_loras)):
@@ -633,6 +636,9 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
                 replacement_loras[i] = None
             lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
                 replacement_loras)
+            # Remove the modules that have been replaced.
+            for module in replaced_module:
+                lora_model.loras.pop(module, None)

     def deactivate_adapter(self, adapter_id: int) -> bool:
         return deactivate_adapter(adapter_id, self._active_adapters,
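
In short, _create_merged_loras_inplace already merged the sub-module LoRAs into a packed entry; the new lines additionally drop the merged sub-module weights from lora_model.loras, since only the packed entry is consulted afterwards. A self-contained sketch of this pack-and-prune step over plain Python containers (the helper name and the tuple standing in for PackedLoRALayerWeights.pack are illustrative, not the vLLM implementation):

    from typing import Dict, List, Optional, Set


    def pack_and_prune(loras: Dict[str, object],
                       packed_modules: Dict[str, List[str]]) -> None:
        """Merge sub-module entries into a packed entry and delete the originals."""
        for module_name, sub_names in packed_modules.items():
            replacement: List[Optional[object]] = []
            replaced: Set[str] = set()
            has_replacement = False
            for name in sub_names:
                lora = loras.get(name)
                replacement.append(lora)
                if lora is not None:
                    has_replacement = True
                    replaced.add(name)
            if not has_replacement:
                continue
            # Stand-in for PackedLoRALayerWeights.pack(replacement).
            loras[module_name] = tuple(replacement)
            # Remove the modules that have been replaced; only the packed
            # entry is used from here on.
            for name in replaced:
                loras.pop(name, None)


    # Example: gate_proj/up_proj are folded into gate_up_proj and then removed.
    loras = {"gate_proj": "A", "up_proj": "B", "q_proj": "C"}
    pack_and_prune(loras, {"gate_up_proj": ["gate_proj", "up_proj"]})
    assert set(loras) == {"gate_up_proj", "q_proj"}
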
2 changes: 1 addition & 1 deletion vllm/lora/punica_wrapper/punica_base.py
@@ -147,7 +147,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
                                             dtype=torch.long,
                                             device=device)

-        # 5 is the number of indicies tensors.
+        # 5 is the number of indices tensors.
         # base_indices, sampler_indices, sampler_indices_padded,
         # embeddings_indices,long_lora_indices
         self.indices_len: List[Optional[int]] = [None] * 5
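
For context, the corrected comment documents the five slots tracked in indices_len. A tiny sketch of that layout, assuming the slot order named in the comment (the constant name is illustrative):

    from typing import List, Optional

    # base_indices, sampler_indices, sampler_indices_padded,
    # embeddings_indices, long_lora_indices
    NUM_INDICES_TENSORS = 5
    indices_len: List[Optional[int]] = [None] * NUM_INDICES_TENSORS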