fix vllm memory leak #3515
Conversation
Test: GRPO training with Tensor Parallelism (TP). (When TP is not used, `request_config.n` equals 1, so this memory leak does not occur.)

```bash
MAX_PIXELS=602112 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
--rlhf_type grpo \
--model 'Qwen/Qwen2.5-VL-7B-Instruct' \
--external_plugins examples/train/grpo/plugin/plugin.py \
--reward_funcs external_r1v_acc format \
--use_vllm true \
--train_type lora \
--torch_dtype bfloat16 \
--dataset 'lmms-lab/multimodal-open-r1-8k-verified#1000' \
--max_completion_length 2048 \
--num_train_epochs 3 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--learning_rate 1e-6 \
--eval_steps 100 \
--save_steps 200 \
--save_total_limit 2 \
--logging_steps 5 \
--warmup_ratio 0.05 \
--dataloader_num_workers 4 \
--dataset_num_proc 4 \
--num_generations 4 \
--temperature 1.0 \
--top_p 0.9 \
--top_k 50 \
--async_generate false \
--system 'examples/train/grpo/prompt.txt' \
--deepspeed zero3 \
--log_completions true \
--num_iterations 1 \
--num_infer_workers 8 \
--gradient_accumulation_steps 2 \
--tensor_parallel_size 4 \
--sleep_level 1 \
    --vllm_gpu_memory_utilization 0.7
```

Add a print log in `grpo_trainer.py` and record its output at step 10:

```python
distributed_idx = self.round_robin(len(all_inputs), get_node_setting()[1] * self.args.num_infer_workers)
if self.infer_rank >= 0:
    engine = self.engine.engine
    if hasattr(engine, 'llm_engine'):
        engine = engine.llm_engine
    print(f"DEBUG: len of self.engine.engine.seq_id_to_seq_group: {len(engine.seq_id_to_seq_group)}")
```

Before the fix:

After the fix:
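As an aside, the `round_robin` call in the debug snippet splits the gathered inputs across infer workers. Below is a minimal sketch of that kind of split, assuming a signature of (number of items, number of workers) returning one index list per worker; the actual helper in swift may differ:

```python
def round_robin(num_items: int, num_workers: int) -> list[list[int]]:
    # Deal indices out one at a time: item i goes to worker i % num_workers,
    # so each worker receives an (almost) equal share of the inputs.
    buckets: list[list[int]] = [[] for _ in range(num_workers)]
    for i in range(num_items):
        buckets[i % num_workers].append(i)
    return buckets


# Example: 10 inputs spread over 4 infer workers.
print(round_robin(10, 4))  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```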
From the review thread, quoting the patch:

```python
def patch_vllm_memory_leak():
    import vllm
    if version.parse(vllm.__version__) != version.parse('0.7.3'):
        return
```
Added a strict version check to avoid compatibility issues with lower versions.
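The excerpt above shows only the version gate; the body of the patch is elided. As an illustration of the general shape such a monkey patch can take, here is a rough sketch, assuming the leak is that `LLMEngine.seq_id_to_seq_group` retains entries for requests that have already finished. The key matching below is a hypothetical simplification (the real cleanup follows the upstream fix in vllm-project/vllm#14326), so treat it as a sketch, not the actual implementation:

```python
from packaging import version


def patch_vllm_memory_leak_sketch():
    import vllm

    # Strict gate: internals differ across releases, so only patch
    # the release that is known to leak.
    if version.parse(vllm.__version__) != version.parse('0.7.3'):
        return

    from vllm.engine.llm_engine import LLMEngine

    original_step = LLMEngine.step

    def patched_step(self):
        outputs = original_step(self)
        # Hypothetical cleanup: drop bookkeeping for finished requests so
        # the dict cannot grow without bound when sampling with n > 1.
        # (The real fix tracks the child request ids that parallel
        # sampling registers; this direct match is a simplification.)
        finished = {o.request_id for o in outputs if o.finished}
        for seq_id in list(self.seq_id_to_seq_group):
            if seq_id in finished:
                del self.seq_id_to_seq_group[seq_id]
        return outputs

    LLMEngine.step = patched_step
```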
nice work
PR type

Bug Fix
PR information
Fix the vLLM 0.7.3 memory leak that occurs when `n > 1`.

Fixes #3508.
Related upstream issue: vllm-project/vllm#14326.
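For reference, a minimal standalone way to observe the counter from the debug snippet above, assuming vLLM 0.7.3 and any small model (the model name here is just an example):

```python
from vllm import LLM, SamplingParams

llm = LLM(model='Qwen/Qwen2.5-0.5B-Instruct')
# n > 1 goes through vLLM's parallel-sampling path, which is where
# the 0.7.3 leak shows up.
params = SamplingParams(n=4, max_tokens=16)

for step in range(5):
    llm.generate(['Hello'] * 8, params)
    # Unpatched 0.7.3: this count keeps growing across calls.
    # Patched: it should stay bounded.
    print(step, len(llm.llm_engine.seq_id_to_seq_group))
```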
Experiment results
See the before/after `seq_id_to_seq_group` logs above.