feature(wzn): add LoRA training demo for Geo3K #41

Open

zunian-wan wants to merge 4 commits into opendilab:main from zunian-wan:dev-geo3k-lora

Conversation

@zunian-wan (Contributor) commented Feb 10, 2026

📋 Summary

Purpose:
Add a demo for Geo3k training using LoRA with FSDP and SGLang

Type of Change:

  • ✨ New feature (non-breaking change which adds functionality)

🔗 Related Issues

Fixes #(issue number)
Related to #(issue number)

📝 Changes

What changed:

Why these changes:

Key implementation details:

🧪 Testing

Test Plan

  • Unit tests: Added/updated unit tests
  • Integration tests: Tested with full training pipeline
  • Manual testing: Describe what you tested manually

Test commands:

# Commands used to test the changes

Test environment:

  • Python Version:
  • PyTorch Version:
  • CUDA Version:
  • GPU Model:
  • Number of GPUs:

Test Results

Test Output
Paste test output here

Before this PR:

# Baseline metrics/behavior

After this PR:

# New metrics/behavior

📊 Performance Impact

  • No performance impact
  • Performance improved:
  • Performance regression:

Benchmark results (if applicable):

Baseline: X samples/sec, Y GB memory
After PR: X samples/sec, Y GB memory

📚 Documentation

  • Docstrings updated for new/modified functions
  • README.md updated (if user-facing changes)
  • Documentation in docs/ updated (if applicable)
  • Examples updated/added (if applicable)
  • Configuration reference updated (if new parameters added)
  • CHANGELOG.md updated

✅ Checklist

Code Quality

  • Code follows the project's style guidelines (run make format and make fcheck)
  • Self-review of code completed
  • Code is well-commented, especially in complex areas
  • No unnecessary debug logs or commented-out code

Compatibility

  • Changes are backward compatible (or breaking changes are documented)
  • Existing tests pass with changes
  • No new warnings introduced

Testing

  • Tested with FSDP (if applicable)
  • Tested with DeepSpeed (if applicable)
  • Tested with inference engines (vLLM/SGLang) (if applicable)
  • Tested on multiple GPU configurations (if applicable)

Documentation

  • All public APIs are documented
  • User-facing changes are documented
  • Migration guide provided (if breaking changes)

🎯 Algorithm/Model Specific (if applicable)

New Algorithm:

  • Algorithm implementation follows existing patterns
  • Algorithm is configurable via CLI arguments
  • Example training script provided
  • Algorithm documentation added to docs/source/quick_start/algorithms.md

New Model Support:

  • Model architecture properly integrated
  • Tested with representative datasets
  • Model-specific documentation added

💭 Additional Notes

🔍 Review Checklist for Maintainers

  • Code quality and style verified
  • Tests are adequate and passing
  • Documentation is complete and clear
  • Performance impact is acceptable
  • Breaking changes are properly documented
  • Ready to merge

📝 Summary of Changes

- Implement LoRA-aware model saving in FSDPV2Strategy, supporting HF/PEFT `save_pretrained` for adapters.
- Add LoRA merging/unmerging logic in BroadcastManager so that inference engines receive the effective (merged) weights during synchronization.
- Optimize checkpointing in PPOTerVL to prioritize HF adapter saving for LoRA runs.
- Add `run_grpo_geo3k_lora_qwen2.5_vl_7b.sh` as a reference LoRA training script.
- Improve weight-name mapping for SGLang to handle PEFT-wrapped module names and `base_layer` stripping.
- Add rotation logic for HF/LoRA adapters in the PPO/SPMD trainers to honor the `max_ckpt_num` parameter.
- Sync the cleanup mechanism with the `save_ckpt` implementation in `FSDPV2Strategy`.
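The merge/unmerge step described above can be illustrated with a minimal numeric sketch (plain floats stand in for weight tensors; the real code relies on PEFT's `merge_adapter`/`unmerge_adapter`). The weight the inference engine should see is `W + (alpha/rank) * B·A`, and unmerging restores the frozen base weight for continued training:

```python
# Minimal numeric sketch of LoRA merge/unmerge around weight sync.
# Plain floats stand in for weight tensors; all names are illustrative.

def lora_merge(w, a, b, alpha, rank):
    """Effective weight the inference engine should see: W + (alpha/rank) * B*A."""
    return w + (alpha / rank) * (b * a)

def lora_unmerge(w_eff, a, b, alpha, rank):
    """Restore the frozen base weight after synchronization."""
    return w_eff - (alpha / rank) * (b * a)

base_w, lora_a, lora_b = 1.0, 0.5, 0.2
merged = lora_merge(base_w, lora_a, lora_b, alpha=256, rank=128)
restored = lora_unmerge(merged, lora_a, lora_b, alpha=256, rank=128)
print(merged, restored)
```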
@zunian-wan changed the title from "Add" to "Add Geo3K training demo using LoRA" on Feb 10, 2026
@zunian-wan changed the title from "Add Geo3K training demo using LoRA" to "Add Geo3K LoRA training demo" on Feb 10, 2026
@zunian-wan changed the title from "Add Geo3K LoRA training demo" to "Add LoRA training demo for Geo3K" on Feb 10, 2026
@puyuan1996 added the "enhancement" (New feature or request) label on Feb 10, 2026
@puyuan1996 changed the title from "Add LoRA training demo for Geo3K" to "feature(wzn): add LoRA training demo for Geo3K" on Feb 10, 2026
…ture and improve LoRA training documentation in running script
# Main modifications for LoRA:
# - Parameter Efficiency: Significantly reduces VRAM usage for 7B+ models.
# - Targeted Adaptation: Adapts all linear layers to maintain reasoning power.
Reviewer (Member):

maintain necessary model capacity

GENERATE_MAX_LEN=2048 # Max length of the generated response.
LORA_RANK=128 # LoRA rank.
LORA_ALPHA=256 # LoRA alpha.
LORA_DROPOUT=0.1 # LoRA dropout rate.
Reviewer (Member):

use default value for this argument, we seldom modify this value
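One way to follow this suggestion is to keep such knobs as CLI arguments with sensible defaults, so the launch script only sets values it actually changes. A sketch with hypothetical argument names (the defaults mirror the script's values and are illustrative):

```python
import argparse

# Hypothetical LoRA argument group; the defaults below are illustrative,
# chosen to mirror the values in the launch script.
parser = argparse.ArgumentParser()
parser.add_argument("--lora-rank", type=int, default=128)
parser.add_argument("--lora-alpha", type=int, default=256)
parser.add_argument("--lora-dropout", type=float, default=0.1)

args = parser.parse_args([])  # no overrides: defaults apply
print(args.lora_rank, args.lora_alpha, args.lora_dropout)  # 128 256 0.1
```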


# --- Single-Node Distributed Setup ---
# Update these if you are running in a multi-node environment.
export MLP_WORKER_NUM=1 # Number of nodes.
Reviewer (Member):

simplify this part, we don't need MLP_XXX

ema_model if args.enable_ema else actor,
tokenizer,
args.save_path,
args.save_path + "/final_model",
Reviewer (Member):

use os.path.join, do we need suffix here?
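The `os.path.join` suggestion avoids manual separator concatenation; a minimal sketch (the `/tmp/ckpt` path stands in for `args.save_path`):

```python
import os

save_path = "/tmp/ckpt"  # illustrative stand-in for args.save_path
final_dir = os.path.join(save_path, "final_model")
print(final_dir)  # /tmp/ckpt/final_model on POSIX systems
```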

max_num = getattr(args, "max_ckpt_num", 3)
while True:
    subdirs = sorted(
        [(os.path.join(args.ckpt_path, d), os.path.getmtime(os.path.join(args.ckpt_path, d)))
Reviewer (Member):

use a temporal variable to denote args.ckpt_path
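The rotation loop, rewritten with a local variable for the checkpoint path as suggested, might look like the following sketch (it assumes checkpoints are plain subdirectories ordered by modification time; all names are illustrative):

```python
import os
import shutil
import tempfile

def rotate_checkpoints(ckpt_path, max_ckpt_num=3):
    """Delete the oldest checkpoint dirs until at most `max_ckpt_num` remain."""
    subdirs = sorted(
        (os.path.join(ckpt_path, d) for d in os.listdir(ckpt_path)
         if os.path.isdir(os.path.join(ckpt_path, d))),
        key=os.path.getmtime,
    )
    for stale in subdirs[:-max_ckpt_num]:  # assumes max_ckpt_num >= 1
        shutil.rmtree(stale)

# Demo: five fake checkpoints rotated down to the three newest.
root = tempfile.mkdtemp()
for i in range(5):
    p = os.path.join(root, f"step_{i}")
    os.mkdir(p)
    os.utime(p, (i, i))  # deterministic mtimes for the demo
rotate_checkpoints(root, max_ckpt_num=3)
print(sorted(os.listdir(root)))  # ['step_2', 'step_3', 'step_4']
```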

"""
self.print("FSDP save model is not implemented, please use offline tools to convert to huggingface model")
# Determine the model to save (unwrap ActorVL or similar wrappers)
actual_model = model.model if is_actor(model) or hasattr(model, "model") else model
Reviewer (Member):

can we only use hasattr(model, "model") here

self.inference_engine.update_weights_from_tensor(
    sglang_name, param.data, flush_cache=(count == num_params)
)
if ".lora_" in name:
Reviewer (Member):

Why is 'continue' used here? It would prevent the LoRA parameters from being transferred from training to inference engine.
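If the merge-before-sync scheme from the PR description is in effect, skipping the raw `.lora_` tensors is intentional: after `merge_adapter()` their delta is already folded into the base weights, and the inference engine has no `lora_A`/`lora_B` parameters to receive them. A name-filtering sketch (parameter names are illustrative; the `.base_layer.` stripping mirrors the PR's weight-mapping change):

```python
# After merging, only base weights are synced; PEFT wrapper names are
# normalized by dropping the ".base_layer." segment. Names are illustrative.
names = [
    "model.layers.0.self_attn.q_proj.base_layer.weight",
    "model.layers.0.self_attn.q_proj.lora_A.default.weight",
    "model.layers.0.self_attn.q_proj.lora_B.default.weight",
]

effective = [
    n.replace(".base_layer.", ".")
    for n in names
    if ".lora_" not in n  # raw adapter tensors are skipped ('continue')
]
print(effective)  # ['model.layers.0.self_attn.q_proj.weight']
```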

is_peft = hasattr(self.actor.model, "merge_adapter")
if is_peft:
    self.strategy.print("Merging LoRA adapters for weight synchronization...")
    self.actor.model.merge_adapter()
Reviewer (Member):

add an else branch to raise RuntimeError
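The requested else branch makes the failure mode explicit rather than silently syncing un-merged weights; a sketch with a stub model (all names are illustrative):

```python
class _PeftStub:
    """Stub standing in for a PEFT-wrapped actor model."""
    merged = False
    def merge_adapter(self):
        self.merged = True

def merge_for_sync(model):
    # Fail fast instead of silently syncing un-merged weights.
    if hasattr(model, "merge_adapter"):
        model.merge_adapter()
    else:
        raise RuntimeError("LoRA weight sync requested, but the model is not PEFT-wrapped")

m = _PeftStub()
merge_for_sync(m)
print(m.merged)  # True

try:
    merge_for_sync(object())
except RuntimeError as e:
    print("raised:", e)
```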

"""
model = self.actor.model
count, num_params = 0, len(list(model.named_parameters()))
param_dict = dict(model.named_parameters())
Reviewer (Member):

use OrderedDict


# Broadcast to engine
if self.strategy.engine_type == "vllm":
    vllm_name = self._map_weight_name_for_sglang(effective_name)
Reviewer (Member):

why sglang method here
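The reviewer's point is that the vLLM branch calls the SGLang-specific mapper; one fix is an explicit per-engine dispatch. A sketch with hypothetical function names (the `.base_layer.` stripping mirrors the PEFT unwrapping described in the PR; other per-engine rules are placeholders):

```python
# Hypothetical per-engine weight-name mappers. The ".base_layer." stripping
# mirrors the PEFT unwrapping described in the PR; other rules are placeholders.
def map_name_for_sglang(name):
    return name.replace(".base_layer.", ".")

def map_name_for_vllm(name):
    return name.replace(".base_layer.", ".")  # vLLM may need different rules

def map_weight_name(engine_type, name):
    if engine_type == "sglang":
        return map_name_for_sglang(name)
    if engine_type == "vllm":
        return map_name_for_vllm(name)
    raise ValueError(f"unsupported engine type: {engine_type}")

print(map_weight_name("vllm", "q_proj.base_layer.weight"))  # q_proj.weight
```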
