Skip to content

[Bugfix] Fix JambaForCausalLM LoRA #14370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 0 additions & 24 deletions tests/lora/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from unittest.mock import MagicMock, patch

import pytest
import safetensors
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
Expand Down Expand Up @@ -191,29 +190,6 @@ def mixtral_lora_files_all_target_modules():
return snapshot_download(repo_id="dyang415/mixtral-lora-v0")


@pytest.fixture(scope="session")
def jamba_lora_files():
# some of the adapters have unnecessary weights for serving,
# hence we remove them
def remove_unnecessary_weights(path):
lora_path = f"{adapter_path}/adapter_model.safetensors"
tensors = safetensors.torch.load_file(lora_path)
nonlora_keys = []
for k in list(tensors.keys()):
if "lora" not in k:
nonlora_keys.append(k)
for k in nonlora_keys:
del tensors[k]
safetensors.torch.save_file(tensors, lora_path)

adapter_path = snapshot_download(
repo_id=
"hf-100/Jamba-1.5-mini-Spellbound-StoryWriter-0.1-6583896-ckpt53-lora")

remove_unnecessary_weights(adapter_path)
return adapter_path


@pytest.fixture(scope="session")
def gemma_lora_files():
return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
Expand Down
54 changes: 0 additions & 54 deletions tests/lora/test_jamba.py

This file was deleted.

3 changes: 3 additions & 0 deletions tests/lora/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def create_random_linear_replicated_layer():

id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_replicated_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras(
id_to_index,
Expand Down Expand Up @@ -757,6 +758,7 @@ def create_random_linear_parallel_layer():

id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras(
id_to_index,
Expand Down Expand Up @@ -904,6 +906,7 @@ class FakeConfig:
id_to_index = get_random_id_to_index(num_loras, max_loras)

linear, lora_linear = create_column_parallel_packed_layer()
assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper)
lora_dict, sublora_dict = populate_loras(
id_to_index,
Expand Down
37 changes: 32 additions & 5 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ def can_replace_layer(
) -> bool:
return type(source_layer) is VocabParallelEmbedding

@property
def weight(self):
return self.base_layer.weight


class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):

Expand Down Expand Up @@ -409,6 +413,34 @@ def apply(self,
self.output_slices)
return output

@property
def weight(self) -> torch.Tensor:

# unquantizedLinear
if hasattr(self.base_layer, "weight"):
return self.base_layer.weight
# Compressed Tensor
elif hasattr(self.base_layer, "weight_packed"):
return self.base_layer.weight_packed
# GPTQ/AWQ
elif hasattr(self.base_layer, "qweight"):
return self.base_layer.qweight
# marlin
elif hasattr(self.base_layer, "B"):
return self.base_layer.B
# HQQ marlin
elif hasattr(self.base_layer, "W_q"):
return self.base_layer.W_q
else:
raise ValueError(f"Unsupported base layer: {self.base_layer}")

@property
def bias(self) -> Optional[torch.Tensor]:
if hasattr(self.base_layer, "bias"):
return self.base_layer.bias
else:
return None


class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):

Expand Down Expand Up @@ -902,11 +934,6 @@ def forward(

return output, output_bias

@property
def weight(self):
return (self.base_layer.weight if hasattr(self.base_layer, "weight")
else self.base_layer.qweight)

@classmethod
@_not_fully_sharded_can_replace
def can_replace_layer(
Expand Down