Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions src/peft/tuners/lora/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,8 @@ class LoraConfig(PeftConfig):
"magnitude is handled by a separate learnable parameter. This can improve the performance of LoRA, "
"especially at low ranks. Right now, DoRA only supports linear and Conv2D layers. DoRA introduces a bigger"
"overhead than pure LoRA, so it is recommended to merge weights for inference."
)
),
"is_lora_variant": True,
},
)
velora_config: Optional[Union[VeloraConfig, dict]] = field(
Expand All @@ -718,7 +719,8 @@ class LoraConfig(PeftConfig):
"help": (
"Enable VeLoRA as a LoRA variant by providing a VeloraConfig. VeLoRA swaps in a custom backward pass "
"for the LoRA A projection that stores compressed activations instead of the full input activations."
)
),
"is_lora_variant": True,
},
)
alora_invocation_tokens: Optional[list[int]] = field(
Expand All @@ -735,7 +737,8 @@ class LoraConfig(PeftConfig):
"operations. Overall adapter inference speedups of an order of magnitude or more can occur on vLLM, "
"depending on the length of the shared context. Note that merging is not possible due to the selective "
"application of the weights."
)
),
"is_lora_variant": True,
},
)
use_qalora: bool = field(
Expand Down Expand Up @@ -827,11 +830,14 @@ class LoraConfig(PeftConfig):
"help": (
"Enable BD-LoRA (Block-Diagonal LoRA) by providing a BdLoraConfig. This technique uses block-diagonal matrices for LoRA-A or LoRA-B "
"factors to enable faster multi-LoRA serving by eliminating communication overheads in distributed settings."
)
),
"is_lora_variant": True,
},
)
arrow_config: Optional[ArrowConfig] = field(
default=None, metadata={"help": "The necessary config to apply arrow routing on the model."}
default=None, metadata={"help": "The necessary config to apply arrow routing on the model.",
"is_lora_variant": True,
},
)
ensure_weight_tying: bool = field(
default=False,
Expand Down
154 changes: 80 additions & 74 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import annotations

import copy
import dataclasses
import math
import warnings
from collections.abc import Callable
Expand Down Expand Up @@ -139,19 +140,51 @@ def __init__(self, base_layer: nn.Module, ephemeral_gpu_offload: bool = False, *
def _get_in_out_features(self, module: nn.Module) -> tuple[int, int] | tuple[None, None]:
return _get_in_out_features(module)

@property
def lora_variants(self):
"""
A dictionary mapping the active LoRA variants to their respective classes.

To extend this, subclasses should override this property and return a dictionary
where the keys are tuples of variant field names (from LoraConfig) and the
values are the specific LoraVariant subclasses.

Tuples are used as keys because they are immutable and hashable, allowing us
to safely map combinations of active variants (e.g., DoRA + another variant)
to a specific composed variant class.
"""
return {(): None}

def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
"""Return a matching LoRA variant for this layer type.

Given the init arguments of this layer, return the correct LoRA variant, if any. E.g., if `use_dora=True`, this
method should return the DoRA variant for the given layer. If `use_alora=True`, same for aLoRA.
# Safely fetch the dictionary (defaults to empty if a subclass forgot to define it)
layer_variants = getattr(self, "lora_variants", {(): None})
lora_variants_configs = [f for f in dataclasses.fields(config) if f.metadata.get("is_lora_variant")]

If there is no fitting variant, return None.
# 1. Gather all valid variant field names from the config
tagged_fields = { f.name for f in lora_variants_configs }

Note: If this layer type does not support the LoRA variant at all, please raise an error during __init__ as is
convention, and not here.
# 2. SANITY CHECK: Ensure all keys in the layer's dictionary actually exist in the config
for variant_keys in layer_variants.keys():
for variant_name in variant_keys:
if variant_name not in tagged_fields:
raise ValueError(
f"Variant '{variant_name}' found in lora_variants but it is not tagged with "
f"'is_lora_variant' in LoraConfig."
)

"""
return None
# 3. Figure out which variants are currently active
active_variants = tuple(sorted(
f.name for f in lora_variants_configs
if getattr(config, f.name)
))

# 4. Route to the correct variant class
if active_variants not in layer_variants:
raise ValueError(f"Invalid or unsupported variant combination: {active_variants}")

variant_class = layer_variants[active_variants]
return variant_class() if variant_class else None

def update_layer(
self,
Expand Down Expand Up @@ -795,36 +828,17 @@ def __init__(
)
self.is_target_conv_1d_layer = is_target_conv_1d_layer

def resolve_lora_variant(self, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
from .variants import VeloraLinearVariant

return VeloraLinearVariant()

if config.arrow_config is not None:
from .variants import ArrowLinearVariant

return ArrowLinearVariant()

if config.monteclora_config is not None:
from .variants import MontecloraLinearVariant

return MontecloraLinearVariant()
if config.use_bdlora is not None:
from .variants import BdLoraLinearVariant

return BdLoraLinearVariant()

use_alora = config.alora_invocation_tokens is not None
if not config.use_dora and not use_alora:
return None

from .variants import ALoraLinearVariant, DoraLinearVariant

if use_alora:
return ALoraLinearVariant()
else:
return DoraLinearVariant()
@property
def lora_variants(self):
from . import variants
return {
(): None,
("use_dora",): variants.DoraLinearVariant,
("arrow_config",): variants.ArrowLinearVariant,
("use_bdlora",): variants.BdLoraLinearVariant,
("alora_invocation_tokens",): variants.ALoraLinearVariant,
("velora_config",): variants.VeloraLinearVariant,
}

def merge(self, safe_merge: bool = False, adapter_names: Optional[list[str]] = None) -> None:
"""
Expand Down Expand Up @@ -1061,15 +1075,13 @@ def __init__(
config=config,
)

def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting embedding layers.")
if not config.use_dora:
return None

from .variants import DoraEmbeddingVariant

return DoraEmbeddingVariant()
@property
def lora_variants(self):
from . import variants
return {
(): None,
("use_dora",): variants.DoraEmbeddingVariant,
}

def update_layer(
self,
Expand Down Expand Up @@ -1670,15 +1682,13 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv2d layer kernel must have 4 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv2d

def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if not config.use_dora:
return None

from .variants import DoraConv2dVariant

return DoraConv2dVariant()
@property
def lora_variants(self):
from . import variants
return {
(): None,
("use_dora",): variants.DoraConv2dVariant,
}


class Conv1d(_ConvNd):
Expand All @@ -1689,15 +1699,13 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv1d layer kernel must have 3 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv1d

def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if not config.use_dora:
return None

from .variants import DoraConv1dVariant

return DoraConv1dVariant()
@property
def lora_variants(self):
from . import variants
return {
(): None,
("use_dora",): variants.DoraConv1dVariant,
}


class Conv3d(_ConvNd):
Expand All @@ -1708,15 +1716,13 @@ def __init__(self, *args, **kwargs):
raise ValueError(f"Conv3d layer kernel must have 5 dimensions, not {self._kernel_dim}")
self.conv_fn = F.conv3d

def resolve_lora_variant(self, *, config: LoraConfig, **kwargs) -> Optional[LoraVariant]:
if config.velora_config is not None:
raise ValueError("VeLoRA does not support adapting conv layers.")
if not config.use_dora:
return None

from .variants import DoraConv3dVariant

return DoraConv3dVariant()
@property
def lora_variants(self):
from . import variants
return {
(): None,
("use_dora",): variants.DoraConv3dVariant,
}


class MultiheadAttention(nn.Module, LoraLayer):
Expand Down
40 changes: 34 additions & 6 deletions tests/test_lora_variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import PropertyMock, patch

import pytest
import torch
from torch import nn
Expand Down Expand Up @@ -174,6 +176,34 @@ def test_dora_params_have_gradients(self):
for layer in layer_names:
assert getattr(peft_model.base_model.model, layer).lora_magnitude_vector["default"].weight.grad is not None

def test_unregistered_variant_raises_error(self):
# 1. Create a config and dummy linear layer
config = LoraConfig()
base_layer = nn.Linear(10, 10)
layer = LoraLinear(base_layer, "default", config, r=8, lora_alpha=8)

# 2. Monkey-patch the lora_variants property to include a fake variant
with patch("peft.tuners.lora.layer.Linear.lora_variants", new_callable=PropertyMock) as mock_variants:
mock_variants.return_value = {("fake_unregistered_variant",): None}

# 3. Assert that the sanity check catches it and throws the right error
with pytest.raises(ValueError, match="Variant 'fake_unregistered_variant' found in lora_variants but it is not tagged with 'is_lora_variant' in LoraConfig."):
layer.resolve_lora_variant(config=config)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please also add a test that checks for the Invalid or unsupported variant combination error? It'll probably require monkey-patching too.


def test_invalid_variant_combination_raises_error(self):
# 1. Create a config with no variants active
config = LoraConfig()
base_layer = nn.Linear(10, 10)
layer = LoraLinear(base_layer, "default", config, r=8, lora_alpha=8)

# 2. Monkey-patch lora_variants to include a valid tagged combo that isn't active
with patch("peft.tuners.lora.layer.Linear.lora_variants", new_callable=PropertyMock) as mock_variants:
mock_variants.return_value = {
("use_dora",): None, # only use_dora is valid, empty combo not listed
}
# 3. Assert invalid combination error is raised
with pytest.raises(ValueError, match="Invalid or unsupported variant combination"):
layer.resolve_lora_variant(config=config)

class TestActivatedLora:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -225,9 +255,8 @@ def test_alora_activation_matches_base_until_invocation(self):

input_ids = torch.tensor([[0, 1, 2, 3]])
start = 2
with lora_model.disable_adapter():
with torch.no_grad():
base_out = lora_model(X=input_ids)
with lora_model.disable_adapter(), torch.no_grad():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's not touch unrelated code.

base_out = lora_model(X=input_ids)

kwargs = get_alora_offsets_for_forward(lora_model, input_ids)
with torch.no_grad():
Expand Down Expand Up @@ -272,9 +301,8 @@ def test_num_beams_error(self):
lora_model.eval()

input_ids = torch.tensor([[0, 1, 2, 3]])
with pytest.raises(ValueError) as e:
with torch.no_grad():
lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3])
with pytest.raises(ValueError) as e, torch.no_grad():
lora_out = lora_model(X=input_ids, num_beams=2, alora_offsets=[3])
assert "Beam search not yet supported for aLoRA." in str(e.value)

def test_gradient_checkpointing_double_forward_raises(self):
Expand Down