fix vllm memory leak #3515

Merged: 5 commits merged from the vllm-leak branch into modelscope:main on Mar 16, 2025

Conversation

@hjh0119 (Collaborator) commented on Mar 15, 2025

PR type

  • Bug Fix

PR information

Fix the vLLM 0.7.3 memory leak that occurs when sampling with n > 1.

Related issues: #3508, vllm-project/vllm#14326
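
For context, a minimal sketch (not part of this PR) of how the leak can be observed with plain vLLM 0.7.3, assuming the default v0 engine path: generate repeatedly with SamplingParams(n=4) and watch the engine's seq_id_to_seq_group map, which should stay near empty between requests but keeps growing before the fix. The model name below is only a placeholder.

    from vllm import LLM, SamplingParams

    llm = LLM(model='Qwen/Qwen2.5-0.5B-Instruct')  # placeholder model; any small model works
    sampling_params = SamplingParams(n=4, max_tokens=32, temperature=1.0)

    for step in range(10):
        llm.generate(['Hello, world'], sampling_params)
        # On vLLM 0.7.3 with n > 1, stale entries accumulate here step after step.
        print(f'step {step}: seq_id_to_seq_group size = {len(llm.llm_engine.seq_id_to_seq_group)}')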


@hjh0119 (Collaborator, Author) commented on Mar 15, 2025

Test: GRPO training with tensor parallelism (TP). (Without TP, request_config.n equals 1, so the memory leak does not occur.)

MAX_PIXELS=602112 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
NPROC_PER_NODE=8 \
swift rlhf \
    --rlhf_type grpo \
    --model 'Qwen/Qwen2.5-VL-7B-Instruct' \
    --external_plugins examples/train/grpo/plugin/plugin.py \
    --reward_funcs external_r1v_acc format \
    --use_vllm true \
    --train_type lora \
    --torch_dtype bfloat16 \
    --dataset 'lmms-lab/multimodal-open-r1-8k-verified#1000' \
    --max_completion_length 2048 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --learning_rate 1e-6 \
    --eval_steps 100 \
    --save_steps 200 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_generations 4 \
    --temperature 1.0 \
    --top_p 0.9 \
    --top_k 50 \
    --async_generate false \
    --system 'examples/train/grpo/prompt.txt' \
    --deepspeed zero3 \
    --log_completions true \
    --num_iterations 1 \
    --num_infer_workers 8 \
    --gradient_accumulation_steps 2 \
    --tensor_parallel_size 4 \
    --sleep_level 1 \
    --vllm_gpu_memory_utilization 0.7    

A debug print was added in grpo_trainer.py; the log output below is from training step 10.

        # Distribute the gathered inputs round-robin across the inference workers.
        distributed_idx = self.round_robin(len(all_inputs), get_node_setting()[1] * self.args.num_infer_workers)
        if self.infer_rank >= 0:
            # Unwrap to the underlying vLLM LLMEngine and report how many entries
            # remain in seq_id_to_seq_group; a steadily growing count indicates the leak.
            engine = self.engine.engine
            if hasattr(engine, 'llm_engine'):
                engine = engine.llm_engine
            print(f"DEBUG: len of self.engine.engine.seq_id_to_seq_group: {len(engine.seq_id_to_seq_group)}")

before fix:

DEBUG: len of self.engine.engine.seq_id_to_seq_group: 40

after fix:

DEBUG: len of self.engine.engine.seq_id_to_seq_group: 1

def patch_vllm_memory_leak():
    import vllm
    # Apply the workaround only on vLLM 0.7.3; other versions are left untouched.
    if version.parse(vllm.__version__) != version.parse('0.7.3'):
        return
@hjh0119 (Collaborator, Author) commented:

Added a strict requirement on the version to avoid compatibility issues with lower versions.
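
As an illustration only, a self-contained sketch of this guard pattern; the import and the call site are assumptions about how such a version-pinned patch is typically wired up, not the exact integration in swift.

    from packaging import version

    def patch_vllm_memory_leak():
        import vllm
        if version.parse(vllm.__version__) != version.parse('0.7.3'):
            return  # only 0.7.3 needs the workaround
        # ... monkeypatch the affected engine internals here (intentionally elided) ...

    # Assumed usage: apply the patch once, before the vLLM engine is constructed.
    patch_vllm_memory_leak()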

@tastelikefeet (Collaborator) commented: nice work

@tastelikefeet merged commit 39fd287 into modelscope:main on Mar 16, 2025 (1 of 2 checks passed).
@hjh0119 deleted the vllm-leak branch on April 8, 2025.