feature(wzn): add LoRA training demo for Geo3K #41
zunian-wan wants to merge 4 commits into opendilab:main
Conversation
- Implement LoRA-aware model saving in FSDPV2Strategy, supporting HF/PEFT `save_pretrained` for adapters.
- Add LoRA merging/unmerging logic in BroadcastManager to ensure inference engines receive effective weights during synchronization.
- Optimize checkpointing in PPOTerVL to prioritize HF adapter saving for LoRA runs.
- Add `run_grpo_geo3k_lora_qwen2.5_vl_7b.sh` as a reference LoRA training script.
- Improve weight mapping for SGLang to handle PEFT-wrapped module names and base layer stripping.
- Add rotation logic for HF/LoRA adapters in PPO/SPMD trainers to honor the `max_ckpt_num` parameter.
- Sync the cleanup mechanism with the `save_ckpt` implementation in `FSDPV2Strategy`.
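The merge/unmerge flow described above can be sketched as a small context manager (a hypothetical helper; it assumes the PEFT-wrapped model exposes the real `merge_adapter`/`unmerge_adapter` methods):

```python
from contextlib import contextmanager

@contextmanager
def merged_adapters(model):
    """Temporarily merge LoRA adapters into the base weights so the
    inference engine receives effective (base + adapter) weights during
    synchronization, then restore the unmerged training state."""
    model.merge_adapter()
    try:
        yield model
    finally:
        model.unmerge_adapter()
```

Wrapping the broadcast in `with merged_adapters(self.actor.model): ...` guarantees the adapters are unmerged again even if synchronization raises.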
…ture and improve LoRA training documentation in running script
# Main modifications for LoRA:
# - Parameter Efficiency: Significantly reduces VRAM usage for 7B+ models.
# - Targeted Adaptation: Adapts all linear layers to maintain reasoning power.
maintain necessary model capacity
GENERATE_MAX_LEN=2048  # Max length of the generated response.
LORA_RANK=128          # LoRA rank.
LORA_ALPHA=256         # LoRA alpha.
LORA_DROPOUT=0.1       # LoRA dropout rate.
Use the default value for this argument; we seldom modify it.
# --- Single-Node Distributed Setup ---
# Update these if you are running in a multi-node environment.
export MLP_WORKER_NUM=1  # Number of nodes.
Simplify this part; we don't need the MLP_XXX variables.
ema_model if args.enable_ema else actor,
tokenizer,
args.save_path,
args.save_path + "/final_model",
Use os.path.join; do we need the suffix here?
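The suggested change could look like this (a minimal sketch; `save_path` stands in for `args.save_path`):

```python
import os

save_path = "/tmp/checkpoints"  # stand-in for args.save_path

# String concatenation misbehaves if save_path already ends with a
# separator; os.path.join normalizes this and stays portable.
final_dir = os.path.join(save_path, "final_model")
```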
max_num = getattr(args, "max_ckpt_num", 3)
while True:
    subdirs = sorted(
        [(os.path.join(args.ckpt_path, d), os.path.getmtime(os.path.join(args.ckpt_path, d)))
Use a temporary variable for args.ckpt_path.
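Applying that suggestion, the rotation loop might read as follows (a hedged sketch, not the PR's exact implementation; `ckpt_path` and `max_num` stand in for the trainer's arguments, and `max_num` is assumed positive):

```python
import os
import shutil

def rotate_checkpoints(ckpt_path, max_num=3):
    """Delete the oldest checkpoint subdirectories until at most
    max_num remain, using mtime as the age criterion."""
    root = ckpt_path  # temporary variable instead of repeating args.ckpt_path
    subdirs = sorted(
        (os.path.join(root, d) for d in os.listdir(root)
         if os.path.isdir(os.path.join(root, d))),
        key=os.path.getmtime,
    )
    # Keep the max_num newest; delete the rest, oldest first.
    for old in subdirs[:-max_num]:
        shutil.rmtree(old)
```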
| """ | ||
| self.print("FSDP save model is not implemented, please use offline tools to convert to huggingface model") | ||
| # Determine the model to save (unwrap ActorVL or similar wrappers) | ||
| actual_model = model.model if is_actor(model) or hasattr(model, "model") else model |
Can we use only hasattr(model, "model") here?
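If `is_actor(model)` implies the wrapper exposes a `.model` attribute, the condition collapses to a single attribute check, equivalently expressed with `getattr` and a default. A sketch with a dummy wrapper (the `Wrapper` class is illustrative, not the real ActorVL):

```python
class Wrapper:
    """Stand-in for ActorVL-style wrappers that hold the real model."""
    def __init__(self, inner):
        self.model = inner

def unwrap(model):
    # A lone attribute check covers the actor wrapper and any other
    # wrapper exposing .model; plain models fall through unchanged.
    return getattr(model, "model", model)
```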
self.inference_engine.update_weights_from_tensor(
    sglang_name, param.data, flush_cache=(count == num_params)
)
if ".lora_" in name:
Why is `continue` used here? It would prevent the LoRA parameters from being transferred from training to the inference engine.
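One possible rationale (an assumption, not confirmed by the diff): once adapters are merged into the base weights before synchronization, the raw `lora_A`/`lora_B` tensors are redundant, and broadcasting them would double-count the update. A filter sketch:

```python
def names_to_sync(param_names):
    """With adapters merged, only base parameters need to be broadcast;
    raw LoRA factors (lora_A/lora_B) already live in the merged weights."""
    return [n for n in param_names if ".lora_" not in n]
```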
is_peft = hasattr(self.actor.model, "merge_adapter")
if is_peft:
    self.strategy.print("Merging LoRA adapters for weight synchronization...")
    self.actor.model.merge_adapter()
Add an else branch that raises a RuntimeError.
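The requested guard might look like this (a sketch; the function name and error message are illustrative):

```python
def merge_for_sync(model, log=print):
    """Merge LoRA adapters before weight sync; fail fast when the model
    is not PEFT-wrapped instead of silently broadcasting stale weights."""
    if hasattr(model, "merge_adapter"):
        log("Merging LoRA adapters for weight synchronization...")
        model.merge_adapter()
    else:
        raise RuntimeError(
            "LoRA sync requested but model has no merge_adapter(); "
            "is it wrapped with PEFT?"
        )
```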
| """ | ||
| model = self.actor.model | ||
| count, num_params = 0, len(list(model.named_parameters())) | ||
| param_dict = dict(model.named_parameters()) |
|
|
||
| # Broadcast to engine | ||
| if self.strategy.engine_type == "vllm": | ||
| vllm_name = self._map_weight_name_for_sglang(effective_name) |
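The PEFT-aware name mapping mentioned in the summary can be sketched as plain string stripping (the prefixes are assumptions based on how PEFT typically wraps modules; the real `_map_weight_name_for_sglang` may differ):

```python
def map_peft_name(name):
    """Translate a PEFT-wrapped parameter name back to the base model's
    naming so the inference engine can match it, e.g.
    base_model.model.layers.0.q_proj.base_layer.weight
      -> layers.0.q_proj.weight"""
    prefix = "base_model.model."
    if name.startswith(prefix):
        name = name[len(prefix):]
    # PEFT inserts a .base_layer module around wrapped linears.
    return name.replace(".base_layer.", ".")
```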
📋 Summary
Purpose:
Add a demo for Geo3K training using LoRA with FSDP and SGLang.