
Conversation

@LonglongaaaGo commented on Dec 21, 2025:

Paper title: AdaMSS: Adaptive Multi-Subspace Approach for Parameter-Efficient Fine-Tuning
Paper: https://neurips.cc/virtual/2025/loc/san-diego/poster/119606
Github page: https://github.com/jzheng20/AdaMSS/tree/main

Summary

This PR adds AdaMSS (Adaptive Multi-Subspace Selection) as a new PEFT tuner with optional ASA (Adaptive Subspace Allocation) for dynamic subspace selection during training.

Implementation

New Tuner: AdaMSS

  • AdaMSSConfig: Configuration with r, num_subspaces, subspace_rank, and ASA parameters
  • AdaMSSLayer: Base layer implementing SVD-based subspace decomposition
  • AdaMSSModel: Tuner model following BaseTuner pattern
  • Linear: AdaMSS-adapted linear layer with trainable subspaces
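
To show how these pieces could fit together, here is a rough, hypothetical sketch of an AdaMSS-adapted linear forward pass: the base weight stays frozen and each subspace contributes a small trainable update, gated by an ASA mask. The class and attribute names below are illustrative, not the PR's actual API.

import torch
import torch.nn as nn

class AdaMSSLinearSketch(nn.Module):
    """Illustrative only: frozen base layer plus masked per-subspace trainable updates."""

    def __init__(self, base_layer: nn.Linear, subspaces):
        super().__init__()
        self.base_layer = base_layer                              # frozen pretrained linear
        self.subspaces = nn.ModuleList(subspaces)                 # one small trainable module per subspace
        self.register_buffer("mask", torch.ones(len(subspaces)))  # ASA mask; all subspaces active by default

    def forward(self, x):
        out = self.base_layer(x)
        for gate, subspace in zip(self.mask, self.subspaces):
            # Each active subspace adds its low-rank correction to the frozen output.
            out = out + gate * subspace(x)
        return out

Each subspace module here could be as simple as two bias-free linear layers of rank subspace_rank; the actual tuner works within the SVD-derived subspaces described under "Algorithm Details" below.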

ASA Features (Optional)

  • Gradient-based importance tracking with exponential moving averages (see the sketch after this list)
  • Dynamic subspace masking with warmup scheduling
  • Two usage patterns:
    • ASACallback for Transformers Trainer
    • update_and_allocate() method following AdaLora convention
  • Device-aware operations and zero overhead when disabled
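
As a rough sketch of the importance tracking and masking mentioned above (the function and variable names and the smoothing factor are assumptions for illustration, not the PR's API):

import torch

BETA = 0.9  # assumed EMA smoothing factor

def update_importance(ema_importance, subspace_grads):
    # One importance score per subspace: gradient magnitude, smoothed over steps.
    scores = torch.stack([g.detach().abs().mean() for g in subspace_grads])
    return BETA * ema_importance + (1 - BETA) * scores

def build_mask(ema_importance, current_kk):
    # Keep the current_kk most important subspaces, zero out the rest.
    mask = torch.zeros_like(ema_importance)
    mask[torch.topk(ema_importance, k=int(current_kk)).indices] = 1.0
    return mask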

Files Modified

Added:

  • src/peft/tuners/adamss/config.py
  • src/peft/tuners/adamss/layer.py
  • src/peft/tuners/adamss/model.py
  • src/peft/tuners/adamss/asa_callback.py
  • src/peft/tuners/adamss/__init__.py - Registers AdaMSS with register_peft_method()

Modified:

  • src/peft/__init__.py - Export AdaMSSConfig, AdaMSSModel, ASACallback
  • src/peft/tuners/__init__.py - Export AdaMSS tuner
  • src/peft/utils/peft_types.py - Add ADAMSS PeftType

Usage Examples

Basic Usage (No ASA)

from peft import AdaMSSConfig, get_peft_model

config = AdaMSSConfig(
    r=100,
    num_subspaces=10,
    subspace_rank=3,
    target_modules=["query", "value"],
)
model = get_peft_model(base_model, config)

With ASA - Callback Pattern

from peft import AdaMSSConfig, get_peft_model, ASACallback

config = AdaMSSConfig(
    r=100,
    num_subspaces=10,
    subspace_rank=3,
    target_modules=["query", "value"],
    use_asa=True,
    target_kk=5,
    init_warmup=50,
    final_warmup=1000,
    mask_interval=100,
)
model = get_peft_model(base_model, config)

asa_callback = ASACallback(
    target_kk=5,
    init_warmup=50,
    final_warmup=1000,
    mask_interval=100,
)

trainer = Trainer(
    model=model,
    callbacks=[asa_callback],
    # ... other args
)
trainer.train()

With ASA - Standard Pattern

from torch.optim import AdamW

# Same config as above
model = get_peft_model(base_model, config)
optimizer = AdamW(model.parameters(), lr=1e-3)
global_step = 0

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = compute_loss(model, batch)
        loss.backward()
        optimizer.step()

        # Update ASA (PEFT standard method)
        model.base_model.update_and_allocate(global_step)

        optimizer.zero_grad()
        global_step += 1

Algorithm Details

AdaMSS Decomposition

  1. Apply SVD to weight matrix: W = U @ S @ V^T
  2. Keep top-r singular values
  3. Cluster into K subspaces using k-means
  4. Initialize trainable subspace parameters A_i, B_i
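
A minimal sketch of these four steps, assuming the clustering runs over the retained right singular vectors (the PR's exact clustering features and initialization may differ):

import torch
from sklearn.cluster import KMeans

def adamss_decompose(weight, r=100, num_subspaces=10, subspace_rank=3):
    # 1. SVD of the frozen weight: W = U @ diag(S) @ Vh
    U, S, Vh = torch.linalg.svd(weight, full_matrices=False)

    # 2. Keep the top-r singular triplets.
    U, S, Vh = U[:, :r], S[:r], Vh[:r, :]

    # 3. Cluster the r retained directions into num_subspaces groups
    #    (here k-means over the right singular vectors, as an assumed criterion).
    labels = KMeans(n_clusters=num_subspaces, n_init=10).fit_predict(Vh.cpu().numpy())

    # 4. One trainable low-rank pair (A_i, B_i) per subspace; the bases stay frozen.
    subspaces = []
    for i in range(num_subspaces):
        idx = torch.as_tensor((labels == i).nonzero()[0])
        U_i, Vh_i = U[:, idx], Vh[idx, :]                          # frozen subspace bases
        A_i = torch.nn.Parameter(torch.zeros(subspace_rank, Vh_i.shape[0]))
        B_i = torch.nn.Parameter(torch.randn(U_i.shape[1], subspace_rank) * 0.01)
        subspaces.append({"U": U_i, "Vh": Vh_i, "A": A_i, "B": B_i})
    return subspaces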

ASA Schedule

if step < init_warmup:
    current_kk = num_subspaces  # no masking during the initial warmup
elif step < final_warmup:
    progress = (step - init_warmup) / (final_warmup - init_warmup)
    # cubic decay from num_subspaces down to target_kk
    current_kk = target_kk + (num_subspaces - target_kk) * (1 - progress**3)
else:
    current_kk = target_kk  # fixed target after warmup
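
With the example configuration above (num_subspaces=10, target_kk=5, init_warmup=50, final_warmup=1000), the cubic term keeps most subspaces active early and only approaches target_kk near the end of the warmup:

def current_kk(step, num_subspaces=10, target_kk=5, init_warmup=50, final_warmup=1000):
    if step < init_warmup:
        return num_subspaces
    if step < final_warmup:
        progress = (step - init_warmup) / (final_warmup - init_warmup)
        return target_kk + (num_subspaces - target_kk) * (1 - progress ** 3)
    return target_kk

for step in (0, 50, 525, 900, 1500):
    print(step, round(current_kk(step), 2))
# 0 -> 10, 50 -> 10.0, 525 -> 9.38, 900 -> 6.42, 1500 -> 5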

Design Notes

  • Follows BaseTuner and BaseTunerLayer patterns consistent with other PEFT tuners
  • update_and_allocate() method follows AdaLora convention for dynamic allocation
  • Supports both 2D and 3D tensor inputs for flexibility across model architectures (see the sketch after this list)
  • Device-aware operations handle CPU/GPU transfers automatically
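
For the 2D/3D note above, the usual pattern (assumed here, not copied from the PR) is to flatten the batch and sequence dimensions before applying the adapter and reshape afterwards:

def apply_adapter(x, adapter):
    # x: (batch, features) or (batch, seq_len, features)
    orig_shape = x.shape
    out = adapter(x.reshape(-1, orig_shape[-1]))
    return out.reshape(*orig_shape[:-1], out.shape[-1])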

@BenjaminBossan (Member) commented:

Thank you for your PR @LonglongaaaGo. We're currently off for the holidays, so a proper review will have to wait for next year. I did skim the code though and just wanted to add a few comments:

  1. There are some files checked into the method_comparison directory; let's remove those.
  2. An example of how to use this method with the callback would be nice, to better understand how the different parts come together.
  3. I would suggest changing the naming to AdamssConfig etc.; it's easier to type and more consistent (e.g. we have LoraConfig, not LoRAConfig).

@BenjaminBossan (Member) commented:

@LonglongaaaGo Please ping me when the PR is ready for review.

@LonglongaaaGo (Author) commented:

> @LonglongaaaGo Please ping me when the PR is ready for review.

Hey @BenjaminBossan, sure! I will let you know once it is ready. Thank you!

@LonglongaaaGo (Author) commented:

This PR has been superseded by #2987.
