fix vllm memory leak #3515

Merged · 5 commits · Mar 16, 2025
4 changes: 2 additions & 2 deletions swift/llm/infer/infer_engine/__init__.py
@@ -11,7 +11,7 @@
from .infer_client import InferClient
from .infer_engine import InferEngine
from .base import BaseInferEngine
-    from .utils import prepare_generation_config, AdapterRequest, set_device_context
+    from .utils import prepare_generation_config, AdapterRequest, set_device_context, patch_vllm_memory_leak
else:
_extra_objects = {k: v for k, v in globals().items() if not k.startswith('_')}
_import_structure = {
@@ -22,7 +22,7 @@
'infer_client': ['InferClient'],
'infer_engine': ['InferEngine'],
'base': ['BaseInferEngine'],
-        'utils': ['prepare_generation_config', 'AdapterRequest', 'set_device_context'],
+        'utils': ['prepare_generation_config', 'AdapterRequest', 'set_device_context', 'patch_vllm_memory_leak'],
}

import sys
3 changes: 2 additions & 1 deletion swift/llm/infer/infer_engine/grpo_vllm_engine.py
@@ -10,7 +10,7 @@
from swift.plugin import Metric
from ..protocol import ChatCompletionResponse, ChatCompletionStreamResponse, RequestConfig
from .patch import patch_auto_config, patch_auto_tokenizer
-from .utils import AdapterRequest
+from .utils import AdapterRequest, patch_vllm_memory_leak

try:
# After setting the environment variables, import vllm. This way of writing allows lint to pass.
@@ -54,6 +54,7 @@ def __init__(
distributed_executor_backend: Optional[str] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
patch_vllm_memory_leak()
self.use_async_engine = use_async_engine
self.processor = get_model_tokenizer(
model_id_or_path,
228 changes: 228 additions & 0 deletions swift/llm/infer/infer_engine/utils.py
@@ -506,3 +506,231 @@ def restore_torch_device_after_vllm_init():
current_device = torch.cuda.current_device()
if origin_device != current_device:
torch.cuda.set_device(origin_device)


def patch_vllm_memory_leak():
import vllm
if version.parse(vllm.__version__) != version.parse('0.7.3'):
return
Review comment (Collaborator Author): Added a strict requirement for the version to avoid compatibility issues with lower versions.


def patch_vllm_abort_seq_group():
from vllm.core.scheduler import Scheduler
from typing import Iterable, Dict
from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus

def new_abort_seq_group(
self,
request_id: Union[str, Iterable[str]],
seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
) -> None:
if isinstance(request_id, str):
request_id = (request_id, )
request_ids = set(request_id)
seq_id_to_seq_group = seq_id_to_seq_group or {}
for state_queue in [self.waiting, self.running, self.swapped]:
aborted_groups: List[SequenceGroup] = []
for seq_group in state_queue:
# When n>1, seq_group.request_id looks like
# foo_parallel_sample_0, while request_ids is just foo, and we
# should resolve it as real_request_id to match.
if seq_group.request_id in seq_id_to_seq_group:
real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
else:
real_request_id = seq_group.request_id
if real_request_id in request_ids:
# Appending aborted group into pending list.
aborted_groups.append(seq_group)
# We can't remove real_request_id in request_ids here,
# because there may be other seq groups sharing the same
# real_request_id
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
# Remove the aborted request from the Mamba cache.
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
continue
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
if aborted_group.request_id in seq_id_to_seq_group:
del seq_id_to_seq_group[aborted_group.request_id]

self._free_seq_group_cross_attn_blocks(aborted_group)

origin_method = Scheduler.abort_seq_group
Scheduler._old_abort_seq_group = origin_method
Scheduler.abort_seq_group = new_abort_seq_group

def patch_vllm_engine():
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.sequence import ExecuteModelRequest

def new_abort_request(self, request_id) -> None:
for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)

origin_method = LLMEngine.abort_request
LLMEngine._old_abort_request = origin_method
LLMEngine.abort_request = new_abort_request

def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
'as performance will be severely degraded otherwise.')

# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0.
virtual_engine = 0

# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc

ctx = self.scheduler_contexts[virtual_engine]

# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()

# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
# The scheduler is also skipped if a single request caused the last
# engine step to fail, and the previous schedule needs to be rerun.
if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc) = self.scheduler[virtual_engine].schedule()

ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()
# When n>1, elements in self.seq_id_to_seq_group should be deleted
# here, otherwise memory leaks.
for finished_request_id in finished_requests_ids:
if finished_request_id in self.seq_id_to_seq_group:
del self.seq_id_to_seq_group[finished_request_id]

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)

if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
scheduler_outputs, allow_async_output_proc)
else:
finished_requests_ids = list()

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

if not scheduler_outputs.is_empty():

# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)

execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)

if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[virtual_engine]

outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)

# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, outputs)
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
outputs = []

# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()

if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()

# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1

# Add results to the output_queue
ctx.append_output(
outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True,
is_first_step_output=is_first_step_output)

if outputs and allow_async_output_proc:
assert len(outputs) == 1, ('Async postprocessor expects only a single output set')

self._advance_to_next_step(outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)

# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)

# Log stats.
self.do_log_stats(scheduler_outputs, outputs)

# Tracing
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
return ctx.request_outputs

if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0

# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()

return ctx.request_outputs

origin_method = LLMEngine.step
LLMEngine._old_step = origin_method
LLMEngine.step = new_step

patch_vllm_abort_seq_group()
patch_vllm_engine()
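
For reference, a minimal usage sketch of the patch introduced above (an illustration, not part of this PR; the model id is a placeholder and any vLLM-supported model should work the same way). As in the grpo_vllm_engine.py change, patch_vllm_memory_leak() is called before the engine is constructed, so the patched Scheduler.abort_seq_group and LLMEngine.step are in place before requests with n>1 are scheduled. The function is exposed at the package level via the __init__.py change above.

# Sketch only: apply the patch before building the engine.
from swift.llm.infer.infer_engine import patch_vllm_memory_leak

patch_vllm_memory_leak()  # no-op unless vllm.__version__ == '0.7.3'

from vllm import LLM, SamplingParams

llm = LLM(model='Qwen/Qwen2.5-0.5B-Instruct')  # placeholder model id

# n > 1 exercises the seq_id_to_seq_group bookkeeping that the patch cleans up.
outputs = llm.generate(['Hello, world!'], SamplingParams(n=2, max_tokens=16))
for output in outputs:
    print([completion.text for completion in output.outputs])

Because the originals are saved as Scheduler._old_abort_seq_group, LLMEngine._old_abort_request, and LLMEngine._old_step, the patch can be reverted by assigning those attributes back if needed.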