Integrate: LoRA vs Full Fine-tuning: An Illusion of Equivalence #2907

@BenjaminBossan

Feature request

In the paper LoRA vs Full Fine-tuning: An Illusion of Equivalence, the authors describe a method to detect the degree of forgetting caused by LoRA by identifying so-called "intruder dimensions". They also describe a method to mitigate this, at the cost of possibly reducing performance on the fine-tuning task.

For PEFT, it could be interesting to add a new function that implements this mitigation on a fine-tuned LoRA model (and possibly other PEFT methods). This issue is for discussing the paper and its possible applications to PEFT. It is also a call for contributions -- if you are interested in creating a PR, please post here before starting the work to avoid duplicate effort.

Design

The proposed mitigation is to identify the intruder dimensions on $\Delta W$ and scale them down. At the moment, it is unclear to me whether, after downscaling, $\Delta W$ can be projected back onto the original lora_A and lora_B. Depending on whether this is possible, I see two ways of handling this.
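
To make the discussion concrete, below is a minimal sketch of what the identify-and-downscale step could look like. It assumes one reading of the paper, namely that an intruder dimension is a top-k left singular vector of the fine-tuned weight $W_0 + \Delta W$ whose best cosine similarity with the singular vectors of the pre-trained $W_0$ falls below threshold_epsilon; the function name is a placeholder and the paper's original code should be treated as the reference.

import torch

def mitigate_intruder_dimensions(
    base_weight, delta_weight, top_k=10, threshold_epsilon=0.5, mitigation_lambda=0.75
):
    W0 = base_weight.float()
    W_ft = W0 + delta_weight.float()

    # SVD of the pre-trained and the fine-tuned weight matrices
    U0, _, _ = torch.linalg.svd(W0, full_matrices=False)
    U, S, Vh = torch.linalg.svd(W_ft, full_matrices=False)

    # max cosine similarity of each fine-tuned left singular vector with any
    # pre-trained left singular vector (columns are unit norm, so dot products suffice)
    max_sim = (U.T @ U0).abs().max(dim=1).values

    # only the top_k largest singular directions are candidates for intruders
    intruder_mask = torch.zeros_like(S, dtype=torch.bool)
    intruder_mask[:top_k] = max_sim[:top_k] < threshold_epsilon

    # scale the singular values of the intruder dimensions by mitigation_lambda
    S_mitigated = torch.where(intruder_mask, S * mitigation_lambda, S)
    W_mitigated = U @ torch.diag(S_mitigated) @ Vh

    # return the mitigated update relative to the pre-trained weight
    return W_mitigated - W0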

Back projection possible

In this case, we can add a new LoRA adapter that is a copy of the existing, fine-tuned one, apply the mitigation, then project back onto lora_A and lora_B of this adapter:

def reduce_intruder_dimension(
    peft_model, old_adapter_name="default", new_adapter_name="new_adapter", top_k=10, threshold_epsilon=0.5, mitigation_lambda=0.75,
):
    # check if the PEFT method is supported
    peft_model.add_adapter(new_adapter_name, peft_model.peft_config[old_adapter_name])
    # copy the LoRA weights to the new adapter
    # apply the mitigation to the new adapter
    peft_model.set_adapter(new_adapter_name)
    return peft_model

For the mitigation_lambda parameter, see section 5 of the paper. top_k is the number of singular vectors to check (could be a float for the top k%), and threshold_epsilon is the cosine-similarity threshold below which a vector is considered an intruder (see the original code).
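
If back projection turns out to be viable, one candidate for the "project back onto lora_A and lora_B" step would be a rank-r truncated SVD of the mitigated $\Delta W$. Note that this is only exact if the mitigated update still has rank at most r; otherwise it is the best rank-r approximation. A rough sketch, with a hypothetical helper name and assuming the standard lora_alpha / r scaling:

import torch

def project_back_to_lora(delta_weight_mitigated, rank, lora_alpha):
    # rank-r truncated SVD of the mitigated update
    U, S, Vh = torch.linalg.svd(delta_weight_mitigated, full_matrices=False)
    U_r, S_r, Vh_r = U[:, :rank], S[:rank], Vh[:rank]

    # undo the lora_alpha / r scaling applied when computing delta_W,
    # and split the singular values evenly between the two factors
    scaling = lora_alpha / rank
    sqrt_S = torch.diag(torch.sqrt(S_r / scaling))
    lora_B = U_r @ sqrt_S   # shape (out_features, r)
    lora_A = sqrt_S @ Vh_r  # shape (r, in_features)
    return lora_A, lora_B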

Back projection not possible

If back projection cannot be done, I would suggest merging and unloading with the new weights, as unmerging would not be trivially possible (though it could possibly be added later).

from peft.tuners.lora import LoraLayer


def merge_and_unload_with_reduced_intruder_dimensions(
    peft_model, adapter_name="default", top_k=10, threshold_epsilon=0.5, mitigation_lambda=0.75,
):
    # check if the PEFT method is supported
    for module in peft_model.modules():
        if isinstance(module, LoraLayer):
            delta_weight = module.get_delta_weight(adapter_name)
            # apply the mitigation to delta_weight, then fold it into the base weight
            module.get_base_layer().weight.data += delta_weight
    # the mitigated delta has already been applied above, so unload without merging again
    return peft_model.unload()
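
Usage would then look roughly like this (model id and paths are placeholders):

from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("base-model-id")          # placeholder
peft_model = PeftModel.from_pretrained(base_model, "path/to/lora-adapter")  # placeholder

merged_model = merge_and_unload_with_reduced_intruder_dimensions(
    peft_model, adapter_name="default", mitigation_lambda=0.75
)
merged_model.save_pretrained("path/to/merged-model")  # placeholder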

Out of scope

This feature does not include the testing code itself. This means it would be up to the user to validate that the trade-off between task performance and forgetting is to their liking (controllable via mitigation_lambda).
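
For illustration, such a validation could be as simple as sweeping mitigation_lambda and comparing task and forgetting metrics; evaluate_task and evaluate_forgetting below are placeholders for user-supplied benchmarks.

import copy

# evaluate_task / evaluate_forgetting are user-supplied placeholder benchmarks
for lam in (1.0, 0.9, 0.75, 0.5):
    mitigated = merge_and_unload_with_reduced_intruder_dimensions(
        copy.deepcopy(peft_model), mitigation_lambda=lam
    )
    print(lam, evaluate_task(mitigated), evaluate_forgetting(mitigated))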

Reference

Ping @reeceshuttle as the first author of the paper.

Your contribution

I can contribute with discussion, guidance, and reviews.
